Merge branch 'ck_tile/refactor' into ck_tile/elementwise

This commit is contained in:
rocking
2024-04-09 02:37:42 +08:00
committed by GitHub
216 changed files with 9200 additions and 4325 deletions

View File

@@ -145,6 +145,22 @@ if(GPU_TARGETS)
else()
message("Building CK for the following targets: ${AMDGPU_TARGETS}")
endif()
if (GPU_TARGETS)
if (GPU_TARGETS MATCHES "gfx9")
add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif()
if (GPU_TARGETS MATCHES "gfx11")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif()
else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON")
set(CK_USE_WMMA "ON")
endif()
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213

View File

@@ -1,27 +1,29 @@
add_custom_target(client_gemm_fastgelu_examples)
if(GPU_TARGETS MATCHES "gfx9")
add_custom_target(client_gemm_fastgelu_examples)
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
client_gemm_fastgelu)
add_custom_target(client_gemm_fastgelu_generic_examples)
add_custom_target(client_gemm_fastgelu_generic_examples)
add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic)
endif()

View File

@@ -1,5 +1,7 @@
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
endif()

View File

@@ -1,15 +1,16 @@
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,5 +1,7 @@
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
endif()

View File

@@ -1,5 +1,7 @@
add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,22 +1,22 @@
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_gemm_quantization gemm_quantization.cpp)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_gemm_quantization gemm_quantization.cpp)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,5 +1,7 @@
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations)
endif()

View File

@@ -1,3 +1,4 @@
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -17,6 +17,11 @@ if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
@@ -95,7 +95,8 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout,
typename OutLayout,
ck::index_t NumNonSpatialDim = 3,
typename ComputeType = InDataType>
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
@@ -186,7 +187,8 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
PassThrough,
PassThrough,
PassThrough,
ComputeType>;
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::bf8_t;
using OutDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using AComputeType = ck::f8_t;
using BComputeType = ck::bf8_t;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeType,
BComputeType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}

View File

@@ -1,2 +1,4 @@
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,4 +1,4 @@
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES))
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,2 +1,4 @@
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,11 +1,13 @@
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -1,3 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
# Fwd scaleadd scaleadd relu
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32
grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp)
@@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke
add_executable(client_grouped_convnd_bwd_data_scale_fp16
grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp)
target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations)
endif()

View File

@@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)

View File

@@ -48,7 +48,10 @@ else()
endif()
endif()
find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations)
find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations)
if(GPU_TARGETS MATCHES "gfx9")
find_package(composable_kernel COMPONENTS device_contraction_operations)
endif()
find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")

View File

@@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
@@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
# FIXME: re-enable this example as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
@@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)

View File

@@ -1,20 +1,3 @@
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
set(target 1)
endif()
endforeach()
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)

View File

@@ -1,8 +1 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)

View File

@@ -1,29 +1,20 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_add_add_fastgelu_xdl)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
add_custom_target(example_gemm_add_add_fastgelu_xdl)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
set(target 1)
endif()
endforeach()
set(gpu_list "")
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)

View File

@@ -1,19 +1,11 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::bf8_t;
using AccDataType = float;
using CShuffleDataType = ck::f8_t;
using OutDataType = ck::f8_t;
using AComputeType = ck::f8_t;
using BComputeType = ck::bf8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeType,
BComputeType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }

View File

@@ -1,25 +1,17 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)

View File

@@ -1,12 +1,3 @@
# dlops
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
# xdlops
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)

View File

@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)

View File

@@ -0,0 +1,394 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <ck/utility/data_type.hpp>
#include <ck/utility/tuple.hpp>
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType, DDataType>;
using EDataType = F32;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout, DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr int NumDMatrices = 2;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<std::vector<ck::index_t>> stride_Ds;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
int k_batch = 128;
bool time_kernel = true;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
std::vector<std::array<const void*, NumDMatrices>> p_Ds = {};
gemm_descs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
p_Ds.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<std::array<Tensor<DDataType>, NumDMatrices>> d_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_result_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_result_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDMatrices> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens);
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices +
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
}
}
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
problem_size.stride_Ds[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
gemm.SetKBatchSize(argument, config.k_batch);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = argument.gemm_kernel_args_[i].karg_;
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{}));
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
c_device_result_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
d_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
cde_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
return pass;
}
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc < 11)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({});
for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
}
}
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg10: k_batch (> 0)\n"
<< "... setting default values." << std::endl;
}
else
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[10]);
problem_size.Ms = argToIntArray(argv[4]);
problem_size.Ns = argToIntArray(argv[5]);
problem_size.Ks = argToIntArray(argv[6]);
problem_size.stride_As = argToIntArray(argv[7]);
problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
}
problem_size.group_count = problem_size.Ms.size();
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -36,7 +36,7 @@ using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(256);
problem_size.Ks.push_back(128);
problem_size.Ms.push_back(128 + rand() % 128);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(1024);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);

View File

@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F8;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
struct ProblemSize final

View File

@@ -1,48 +1,41 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
add_example_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)
add_example_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
set(target 1)
endif()
endforeach()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()

View File

@@ -1,14 +1,7 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
endif()
set(target 1)
endif()
endforeach()
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
endif()
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
if(result EQUAL 0)

View File

@@ -1,29 +1,15 @@
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
set(target 1)
endif()
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
set(target 1)
endif()
endforeach()
add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)

View File

@@ -1,12 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)

View File

@@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear)
# FP32
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
# FP64
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
# FP16
add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
# BF16
add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
add_dependencies(example_contraction example_contraction_scale)
add_dependencies(example_contraction example_contraction_bilinear)
add_example_dependencies(example_contraction example_contraction_scale)
add_example_dependencies(example_contraction example_contraction_bilinear)

View File

@@ -1,5 +1,2 @@
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)

View File

@@ -1,40 +1,23 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
add_custom_target(example_grouped_conv_fwd_multiple_d)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_custom_target(example_grouped_conv_fwd_multiple_d)
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() # USE_BITINT_EXTENSION_INT4
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() # USE_BITINT_EXTENSION_INT4
set(target 1)
endif()
endforeach()
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)

View File

@@ -1,17 +1,9 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)

View File

@@ -1,11 +1,9 @@
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
endif()
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)

View File

@@ -1,32 +1,23 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_splitK_gemm_xdl)
add_custom_target(example_splitK_gemm_xdl)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()
set(target 1)
endif()
endforeach()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()

View File

@@ -1,27 +1,10 @@
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data)
add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
set(target 1)
endif()
endforeach()
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
set(target 1)
endif()
endforeach()
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)

View File

@@ -1,24 +1,17 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
# Conv + bias + relu perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
# Conv + bias + relu perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
# Conv + bias + tanh perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
# Conv + bias + relu perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
# Conv + bias + relu perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
# Conv + bias + tanh perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)

View File

@@ -1,17 +1,9 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx908 gfx90a)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)

View File

@@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu
add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp)
add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp)
add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp)
add_example_executable(example_elementwise_permute elementwise_permute.cpp)
if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)

View File

@@ -0,0 +1,140 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using BinaryAdd = ck::tensor_operation::element_wise::Add;
// B = alpha * A0 * A0 + beta * A1 * A1
using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise::
BinaryWithUnaryCombinedOp<BinaryAdd, UnaryScaleSquare, UnaryScaleSquare>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType, ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
BinaryAddUnaryScaleSquare, // ElementwiseOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8, 8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides)};
Tensor<ADataType>& a0 = as[0];
Tensor<ADataType>& a1 = as[1];
Tensor<BDataType> b(ab_lengths, ab_strides);
float alpha = 3.f;
float beta = 2.f;
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0.mData.data());
a1_device_buf.ToDevice(a1.mData.data());
std::array<const void*, 2> inputs = {a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths,
{ab_strides, ab_strides},
{ab_strides},
inputs,
output,
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
std::cout << "B (nchw): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
Tensor<BDataType> host_b(ab_lengths, ab_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as,
host_b,
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
{
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_ndhwc(n, d, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -51,32 +39,7 @@ int main()
std::vector<std::size_t> ncdhw = {16, 8, 8, 8, 8};
std::vector<std::size_t> ndhwc = {16, 8, 8, 8, 8};
Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(ndhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths;
/**std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[4]),
1};
std::array<ck::index_t, 5> b_strides = {
static_cast<int>(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]),
static_cast<int>(ndhwc[2] * ndhwc[3] * ndhwc[4]),
1,
static_cast<int>(ndhwc[3] * ndhwc[4]),
static_cast<int>(ndhwc[4])};**/
std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
@@ -93,6 +56,20 @@ int main()
1};
ck::ranges::copy(ncdhw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -126,10 +103,16 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ndhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<4>, // InScalarPerVectorSeq
ck::Sequence<4>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
{
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_ndhwc(n, d, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -59,10 +47,13 @@ int main()
const int W = 5;
const int D = 16;
std::vector<std::size_t> ncdhw = {N, C, D, H, W};
std::vector<std::size_t> ndhwc = {N, D, H, W, C};
Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(ndhwc);
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -74,10 +65,6 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -94,11 +81,12 @@ int main()
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] *
ab_lengths[3] * ab_lengths[4];
std::size_t num_btype =
sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
(sizeof(ADataType) + sizeof(BDataType)) *
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -111,10 +99,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ndhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -55,18 +44,6 @@ int main()
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -77,9 +54,22 @@ int main()
1,
static_cast<int>(nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[3])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -111,10 +101,16 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
@@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
const std::vector<std::size_t>& shape_nchw,
Functor functor)
{
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -54,13 +40,16 @@ int main()
const int N = 120;
const int C = 128;
const int H = 32;
const int W = 1024;
const int W = 32;
std::vector<std::size_t> nchw = {N, C, H, W};
std::vector<std::size_t> nhwc = {N, H, W, C};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -72,11 +61,6 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -94,10 +78,11 @@ int main()
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t flop =
std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) *
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -110,11 +95,16 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{});
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -6,9 +6,11 @@
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -21,11 +23,14 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}
int main()
{
bool do_verification = true;
@@ -60,8 +48,21 @@ int main()
std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 1.f;
auto i = 0;
std::mt19937 gen(11939);
@@ -84,22 +85,14 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -113,11 +106,10 @@ int main()
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype =
(2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
@@ -129,10 +121,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -5,9 +5,11 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -55,18 +47,6 @@ int main()
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -80,9 +60,29 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -112,10 +112,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -5,9 +5,11 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {5, 4, 2, 3};
std::vector<std::size_t> nhwc = {5, 2, 3, 4};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 1.f;
auto i = 0;
@@ -84,22 +86,14 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -129,10 +123,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -5,9 +5,11 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -55,18 +47,6 @@ int main()
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -80,9 +60,28 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -112,10 +111,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using BinaryAdd = ck::tensor_operation::element_wise::Add;
// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2
using TrinaryAddUnaryScaleSquare =
ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp<BinaryAdd,
BinaryAdd,
UnaryScaleSquare,
UnaryScaleSquare,
UnaryScaleSquare>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType, ADataType, ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
TrinaryAddUnaryScaleSquare, // ElementwiseOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 3> as = {Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides)};
Tensor<ADataType>& a0 = as[0];
Tensor<ADataType>& a1 = as[1];
Tensor<ADataType>& a2 = as[2];
Tensor<BDataType> b(ab_lengths, ab_strides);
float alpha = 3.f;
float beta = 2.f;
float gamma = 4.f;
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a2.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0.mData.data());
a1_device_buf.ToDevice(a1.mData.data());
a2_device_buf.ToDevice(a2.mData.data());
std::array<const void*, 3> inputs = {a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
a2_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths,
{ab_strides, ab_strides, ab_strides},
{ab_strides},
inputs,
output,
TrinaryAddUnaryScaleSquare{
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
std::cout << "A2 (nchw): " << a2.mDesc << std::endl;
std::cout << "B (nchw): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
Tensor<BDataType> host_b(ab_lengths, ab_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as,
host_b,
TrinaryAddUnaryScaleSquare{
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
ref_invoker.Run(ref_argument);
const double threshold = std::pow(2, -10) * 2;
b_device_buf.FromDevice(b.mData.data());
pass &= ck::utils::check_err(
b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold);
}
return pass ? 0 : 1;
}

View File

@@ -1,8 +1 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp)

View File

@@ -1,15 +1,7 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_im2col_col2im)
add_custom_target(example_im2col_col2im)
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)
set(target 1)
endif()
endforeach()
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)

View File

@@ -1,8 +1 @@
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)

View File

@@ -1,8 +1 @@
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)

View File

@@ -2,16 +2,9 @@ add_subdirectory(binary)
add_subdirectory(multi_AB)
add_subdirectory(unary)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_activ_xdl)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
set(target 1)
endif()
endforeach()
add_custom_target(example_convnd_activ_xdl)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)

View File

@@ -1,5 +1,3 @@
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
endif()
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)

View File

@@ -5,6 +5,12 @@ include_directories(BEFORE
add_custom_target(examples)
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(FILE_NAME)
add_dependencies(EXAMPLE_NAME FILE_NAME)
endif()
endfunction(add_example_dependencies EXAMPLE_NAME)
function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
@@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})

View File

@@ -56,12 +56,13 @@ auto create_args(int argc, char* argv[])
.insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left local-attn with left right size\n"
"'b:l,r', bottom-r local-attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size\n")
.insert(
"mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left sliding window attn with left right size\n"
"'b:l,r', bottom-r sliding window attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
@@ -357,44 +358,45 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
return fmha_fwd_args<FmhaDefaultElementFunctions>{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.y,
mask.x,
ck_tile::identity{},
ck_tile::identity{}};
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
ck_tile::identity{},
ck_tile::identity{}};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
@@ -508,12 +510,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
}
else
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
if(lse)
{

View File

@@ -104,169 +104,6 @@ struct FmhaMasks
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
#if 0
// internal API, don't use this directly
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t nhead_k,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t max_seqlen_q,
float scale,
bool i_perm,
bool o_perm,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x)
{
constexpr bool is_v_rowmajor =
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias'
/// are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_k : nhead_k * seqlen_k;
}();
const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_k : seqlen_k;
}();
const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k);
const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1);
const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
mask_y,
mask_x);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask_y,
mask_x);
}
}();
dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// This is the args from caller to underneath API, different from the kernel
struct fmha_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t batch;
ck_tile::index_t nhead;
ck_tile::index_t nhead_k;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t max_seqlen_q;
float scale;
bool i_perm;
bool o_perm;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
};
#endif
// runtime args, some will passed to karg, some will used to compute grids/blocks
template <typename ElementFunctions>
struct fmha_fwd_args
@@ -306,8 +143,9 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// typename ElementFunctions::QElementFunction q_element_func;
// typename ElementFunctions::KElementFunction k_element_func;
// typename ElementFunctions::VElementFunction v_element_func;
@@ -350,8 +188,9 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
@@ -392,8 +231,9 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.batch_stride_bias,
args.batch_stride_lse,
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
@@ -420,6 +260,7 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
bool kHasBias_,
bool kStoreLse_,
@@ -439,6 +280,7 @@ struct fmha_fwd_traits_
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;

View File

@@ -24,6 +24,16 @@ DTYPE_BITS = {
"bf8" : 8
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
@@ -49,13 +59,16 @@ ELEMENT_FUNC_MAP = {
"no" : "FmhaDefaultElementFunctions",
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
MASKS = ["no", "causal", "generic"]
DIRECTIONS = ["fwd"]
GEN_DIR = "" # in Cmake, have to generate files in same folder
@@ -130,7 +143,8 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
@@ -168,17 +182,40 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
"""
MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
def get_mask_map(mask : str):
if mask == "generic":
return MASK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_check_map(mask : str):
if mask == "generic":
return MASK_CHECK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
@@ -212,14 +249,19 @@ class FmhaFwdApiTrait:
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
@@ -228,7 +270,7 @@ class FmhaFwdApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@@ -239,7 +281,7 @@ class FmhaFwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
@@ -270,13 +312,17 @@ class FmhaFwdPipeline:
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias == 't' : n += '_bias'
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
return n
class FmhaFwdApiPool:
def __init__(self):
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
@@ -298,8 +344,9 @@ class FmhaFwdApiPool:
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask],
F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
@@ -341,6 +388,7 @@ class FmhaFwdKernel:
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
F_element_func : str
@property
@@ -369,8 +417,9 @@ class FmhaFwdKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy ,
F_mask = MASK_MAP[self.F_pipeline.F_mask],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_element_func = ELEMENT_FUNC_MAP[self.F_element_func])
@@ -426,14 +475,17 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask))
@@ -446,16 +498,19 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask))
if receipt == 1:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool()
api_pool = FmhaFwdApiPool(mask_impl)
for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
@@ -474,6 +529,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_element_func=element_func)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
@@ -489,24 +545,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None:
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
api_pool, kernels = get_blobs(kernel_filter)
api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_api(api_pool, output_dir)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None:
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
with file_path.open('a') as f:
_, kernels = get_blobs(kernel_filter)
_, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -535,8 +591,26 @@ if __name__ == "__main__":
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim"
)
args = parser.parse_args()
if args.list_blobs is not None:
list_blobs(args.list_blobs, args.filter)
list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
else:
write_blobs(args.output_dir, args.filter)
write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)

View File

@@ -9,11 +9,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
causal_top_left,
causal_bottom_right,
mask_top_left,
mask_bottom_right,
window_generic,
};
@@ -21,18 +22,19 @@ struct mask_info
{
mask_enum type;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::causal_top_left)
os << "tl";
else if(type == mask_enum::causal_bottom_right)
os << "br";
else if(type == mask_enum::mask_top_left)
os << "tl(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "br(" << left << ":" << right << ")";
else
{
os << "g(" << y << "/" << x << ")";
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
@@ -57,22 +59,30 @@ struct mask_info
// TODO: some validation
if(t == "t")
{
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
@@ -84,15 +94,19 @@ struct mask_info
{
// should be 0, 1, 2
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::causal_top_left)
if(tmp.type == mask_enum::mask_top_left)
{
tmp.y = seqlen_q;
tmp.x = 1;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(tmp.type == mask_enum::causal_bottom_right)
else if(tmp.type == mask_enum::mask_bottom_right)
{
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
}
return tmp;

View File

@@ -23,7 +23,8 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done

View File

@@ -45,6 +45,10 @@
#endif
// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
@@ -62,8 +66,7 @@
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__)
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
@@ -75,8 +78,7 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
defined(__gfx94__) // for GPU code
#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
@@ -89,7 +91,7 @@
// MFMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#elif defined(__gfx9__) // for GPU code
#define CK_USE_AMD_MFMA
#endif
@@ -120,7 +122,7 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#elif defined(__gfx9__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
* \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
* \tparam BComputeType Compute data type for B tensor (default: AComputeType).
*/
template <index_t NDimSpatial,
typename ALayout,
@@ -54,12 +55,13 @@ template <index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeType =
typename AComputeType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>())> // ComputeType is InputType by default (first
ADataType>()), // AComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeType = AComputeType>
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
{
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;

View File

@@ -0,0 +1,136 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
__host__ __device__
GroupedGemmMultipleDKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -22,10 +22,12 @@ namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
index_t NumDim, // The max dim of input tensors
// the tensors descs have to be aligned, such that
// the innermost dim is the contiguous one.
index_t MPerThread, // How many elements per thread to read
typename InScalarPerVectorSeq, // Scalar per vec for each Input
typename OutScalarPerVectorSeq> // Scalar per vec for each Output
struct DeviceElementwiseImpl
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
@@ -242,13 +244,13 @@ struct DeviceElementwiseImpl
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false;
valid = valid && false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false;
valid = valid && false;
});
return valid;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -254,13 +254,14 @@ template <index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeDataType =
typename AComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler LoopSched = make_default_loop_scheduler()>
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputeDataType>
AComputeDataType,
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType
// Use appropriate gridwise gemm
using GridwiseGemm =
std::conditional_t<isMultiA || isMultiB,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -75,13 +75,14 @@ template <index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeDataType =
typename AComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler LoopSched = make_default_loop_scheduler()>
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
ALayout,
@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
ComputeDataType,
AComputeDataType,
BComputeDataType,
LoopSched>;
} // namespace device

View File

@@ -0,0 +1,987 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/common_header.hpp"
#include <ck/utility/loop_scheduler.hpp>
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = EDataType,
// TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
enable_if_t<AK1 == BK1, bool> = false>
struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
: public DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1
static constexpr index_t K0PerBlock = KPerBlock / AK1;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using WorkspaceDataType = float;
// First stage GridwiseGEMM kernel.
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType,
BDataType,
AccDataType,
WorkspaceDataType,
ALayout,
BLayout,
ELayout,
AElementwiseOperation,
BElementwiseOperation,
PassThrough, // CElementwiseOperation
GemmSpec,
NumGemmKPrefetchStage,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
AK1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVer,
ComputeDataType>;
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
static constexpr auto MakeElementwiseInputSequence()
{
return generate_sequence_v2(
[&]([[maybe_unused]] auto i) constexpr {
return Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
},
Number<NumDTensor + 1>{});
}
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using DsGridPointer = decltype(MakeDsGridPointer());
using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
using CDDataTypes = decltype(concat_tuple(ck::Tuple<WorkspaceDataType*>{}, DsGridPointer{}));
using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence());
static constexpr index_t ClusterLengthMPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise =
GridwiseElementwise<CDGridDesc_M_N,
ck::Tuple<EGridDesc_M_N>,
CDDataTypes,
ck::Tuple<EDataType*>,
Block2TileMap,
CDEElementwiseOperation,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<0, 1>,
ElementwiseInputSequence,
ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>,
true>;
// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using GemmKernelArgument = typename GridwiseGemm::Argument;
struct GemmTransKernelArg
{
GemmKernelArgument karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(GemmKernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
}
};
static constexpr index_t DefaultKBatch = 1;
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: Argument(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_element_op,
b_element_op,
cde_element_op,
DefaultKBatch)
{
}
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t kbatch)
: K_BATCH{kbatch},
group_count_{0},
skipped_group_count_{0},
grid_size_{0},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
p_Ds_{p_Ds}
{
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("Error! group_count_ != p_As/Bs/Ds/Es size");
}
gemm_kernel_args_.reserve(group_count_);
elementwise_c_grid_descs_m_n_.reserve(group_count_);
elementwise_d_grid_descs_m_n_.reserve(group_count_);
ds_grid_pointer_.reserve(group_count_);
group_grid_size_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M * N * K == 0)
{
skipped_group_count_++;
continue;
}
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_e = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e);
DsGridDesc_M_N ds_grid_desc_m_n;
DsGridPointer p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
});
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
group_grid_size_[i] = grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
std::array<index_t, NumDTensor> stride_ds;
static_for<0, NumDTensor, 1>{}([&](auto j) {
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
throw std::runtime_error(
"Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
stride_ds[j] = gemm_descs[i].stride_Ds_[j];
});
stride_Ds_.emplace_back(std::move(stride_ds));
// We first set E pointer to actual operation output, but later on
// when workspace will be set, this will be updated to workspace memory.
auto karg = GemmKernelArgument{type_convert<const ADataType*>(p_As[i]),
type_convert<const BDataType*>(p_Bs[i]),
type_convert<WorkspaceDataType*>(p_Es[i]),
M,
N,
K,
stride_a,
stride_b,
stride_e,
m_padded,
n_padded,
k_padded,
k0_padded,
K_BATCH};
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
ds_grid_pointer_.push_back(p_ds_grid);
}
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_ = p_Es;
}
/**
* @brief Set new kbatch value.
*
* @param[in] kbatch The new splitK parameter value.
*/
void UpdateKBatch(index_t kbatch)
{
K_BATCH = kbatch;
grid_size_ = 0;
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
group_grid_size_[i] = grid_size_grp;
karg.KPadded = k_padded;
karg.K0Padded = k0_padded;
karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
#if DEBUG_LOG
index_t tiles = (block_end - block_start) / K_BATCH;
std::cout << "block_start: " << block_start << "\n"
<< "block_end: " << block_end << "\n"
<< "tiles: " << tiles << std::endl
<< std::endl;
std::cout << "KPadded: " << karg.KPadded << std::endl
<< "K0Padded: " << karg.K0Padded << std::endl
<< "KBatch: " << karg.k_batch << std::endl
<< "grid_size_: " << karg.KPadded << std::endl;
#endif
}
}
void UpdateEPointers()
{
// set-up each group E pointer to it's designated workspace memory.
WorkspaceDataType* p_workspace = reinterpret_cast<WorkspaceDataType*>(p_workspace_);
std::size_t offset = 0;
for(auto& arg : gemm_kernel_args_)
{
arg.karg_.p_c_grid = p_workspace + offset;
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
offset += tiles * MPerBlock * NPerBlock;
#if DEBUG_LOG
std::cout << "block_start: " << arg.block_start_ << "\n"
<< "block_end: " << arg.block_end_ << "\n"
<< "tiles: " << tiles << "\n"
<< "offset: " << offset << std::endl;
#endif
}
}
std::size_t GetWorkspaceSizeBytes() const
{
std::size_t size_bytes{0};
for(const auto& arg : gemm_kernel_args_)
{
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType);
}
return size_bytes;
}
std::size_t GetWorkspaceSize(std::size_t group) const
{
const auto& arg = gemm_kernel_args_[group];
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
return tiles * MPerBlock * NPerBlock;
}
// private:
index_t K_BATCH;
index_t group_count_;
index_t skipped_group_count_;
index_t grid_size_;
// Pointer to device memory with GEMM kernel arguments.
const void* p_dev_gemm_args_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
std::vector<std::array<const void*, NumDTensor>>& p_Ds_;
std::vector<std::array<index_t, NumDTensor>> stride_Ds_;
std::vector<GemmTransKernelArg> gemm_kernel_args_;
std::vector<index_t> group_grid_size_;
std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
std::vector<DsGridPointer> ds_grid_pointer_;
std::vector<void*> e_ptrs_;
};
// Invoker
struct Invoker : public BaseInvoker
{
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload is using user provided device buffer for kernel
/// arguments.
///
/// @param[in] arg The structure containing kernel arguments (in host
/// memory).
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
/// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary
/// workspace.
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{})
{
auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(dev_gemm_workspace == nullptr)
{
std::ostringstream err;
err << "The gemm workspace buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
float ave_time = 0;
if(all_have_main_k_block_loop)
{
ave_time =
DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
else
{
ave_time =
DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
return ave_time;
}
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload is using device buffers (for kernel arguments and
/// for kernel auxiliary workspace) provided with an argument. The user should
/// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see
/// SetDeviceKernelArgs, @see SetWorkSpacePointer on arg parameter to properly
/// allocate those buffers.
///
/// @param[in] arg The structure containing kernel arguments (in host memory).
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_dev_gemm_args_ == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(arg.p_workspace_ == nullptr)
{
std::ostringstream err;
err << "The gemm workspace buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
private:
auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
{
bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(
arg.gemm_kernel_args_[0].karg_.M,
arg.gemm_kernel_args_[0].karg_.MPadded,
arg.gemm_kernel_args_[0].karg_.K,
arg.gemm_kernel_args_[0].karg_.StrideA,
arg.gemm_kernel_args_[0].karg_.k_batch,
arg.gemm_kernel_args_[0].karg_.K0Padded,
arg.gemm_kernel_args_[0].karg_.KPadded);
all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
}
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
if(stream_config.log_level_ > 0)
{
gemm_arg.Print();
}
if(!GridwiseGemm::CheckValidity(gemm_arg))
{
std::ostringstream err;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
gemm_arg.MPadded,
gemm_arg.K,
gemm_arg.StrideA,
gemm_arg.k_batch,
gemm_arg.K0Padded,
gemm_arg.KPadded);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
bool not_all_have_kbatch_value_same =
all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same)
{
std::ostringstream err;
err << "Not all gemms have same kbatch value (=1 or >1)! "
<< "group [" << i << "], kbatch: " << gemm_arg.k_batch
<< ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
}
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
const auto gemm_kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
HasMainKBlockLoop,
InMemoryDataOperationEnum::AtomicAdd,
AElementwiseOperation,
BElementwiseOperation,
PassThrough>;
const auto elementwise_kernel = kernel_elementwise<GridwiseElementwise,
CDGridDesc_M_N,
ck::Tuple<EGridDesc_M_N>,
CDDataTypes,
ck::Tuple<EDataType*>,
Block2TileMap,
CDEElementwiseOperation>;
return LaunchKernel(gemm_kernel,
elementwise_kernel,
arg,
dev_gemm_args,
dev_gemm_workspace,
stream_config);
}
template <typename KernelFunction, typename KernelFunction2>
float LaunchKernel(const KernelFunction& gemm_kernel,
const KernelFunction2& elementwise_kernel,
const Argument& arg,
const void* dev_gemm_args,
[[maybe_unused]] void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
float time{0.f};
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
};
// GEMM kernel
time = launch_and_time_kernel_with_preprocess(
stream_config,
preprocess,
gemm_kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
arg.group_count_,
arg.a_element_op_,
arg.b_element_op_,
PassThrough{});
// Elementwise kernels
for(int i = 0; i < arg.group_count_; ++i)
{
time += launch_and_time_kernel(
stream_config,
elementwise_kernel,
dim3(arg.group_grid_size_[i]),
dim3(BlockSize),
0,
concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
arg.elementwise_d_grid_descs_m_n_[i]),
make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
arg.ds_grid_pointer_[i]),
type_convert<EDataType*>(arg.e_ptrs_[i]),
Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0),
arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)},
arg.cde_element_op_);
}
return time;
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
return false;
}
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
if(not group_arg_valid)
{
#if DEBUG_LOG
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print();
#endif // DEBUG_LOG
}
supported = supported && group_arg_valid;
}
return supported;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op)
{
return Argument{p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op) override
{
return std::make_unique<Argument>(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage"
<< "<"
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","
<< std::string(ELayout::name)[0] << ","
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< ">";
// clang-format on
return str.str();
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
arg.gemm_kernel_args_.data(),
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->GetWorkspaceSizeBytes();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
void SetWorkSpacePointer(
BaseArgument* p_arg,
void* p_workspace,
[[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->p_workspace_ = p_workspace;
p_arg_->UpdateEPointers();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -23,6 +23,7 @@ namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
bool Zeroing,
typename ALayout,
typename BLayout,
typename DsLayout,
@@ -106,33 +107,63 @@ __global__ void
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
auto barrier_count_finished =
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
if constexpr(Zeroing)
{
auto barrier_count_finished =
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
barrier_count_finished,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
block_2_etile_map);
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
barrier_count_finished,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
block_2_etile_map);
GridwiseGemm::template Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
nullptr,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
block_2_etile_map);
}
id_off += grid_size_grp;
id_local += grid_size_grp;
@@ -193,8 +224,11 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeType = ADataType,
typename ALDSType = ComputeType,
typename BLDSType = ComputeType>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using AComputeType = ComputeType;
using BComputeType = ComputeType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeType,
AComputeType,
BComputeType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
LoopSched,
PipelineVer,
ALDSType,
BLDSType>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
@@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
const auto kernel =
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
DsDataType,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>;
if(arg.k_batch_ == 1)
{
const auto kernel =
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
false,
ALayout,
BLayout,
DsLayout,
ELayout,
DsDataType,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
reinterpret_cast<uint32_t*>(arg.p_workspace_),
arg.barrier_size_grp_,
arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp_,
arg.k_batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
nullptr,
arg.barrier_size_grp_,
arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp_,
arg.k_batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
true,
ALayout,
BLayout,
DsLayout,
ELayout,
DsDataType,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
reinterpret_cast<uint32_t*>(arg.p_workspace_),
arg.barrier_size_grp_,
arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp_,
arg.k_batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
};
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
// in IsSupportedArgument function
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is
// enforced in IsSupportedArgument function
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
if(has_main_k_block_loop)
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
bool supported = true;
// If we use padding we do not support vector loads for dimensions not divisible by vector
// load size.
// If we use padding we do not support vector loads for dimensions not divisible by
// vector load size.
if constexpr(GemmSpec != GemmSpecialization::Default)
{
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
// thus we have to adapt it to the {M,K} or {N,K} layout.
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
// layout, thus we have to adapt it to the {M,K} or {N,K} layout.
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -26,13 +26,19 @@ namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count)
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
@@ -64,10 +70,16 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
gemm_desc_ptr[group_id].block_2_ctile_map_,
a_element_op,
b_element_op,
c_element_op);
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct GemmTransKernelArg
{
KernelArgument karg_;
@@ -437,7 +449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
arg.gemm_kernel_args_.size(),
PassThrough{},
PassThrough{},
PassThrough{});
};
if(all_have_main_k0_block_loop)

View File

@@ -92,6 +92,110 @@ struct Add
};
};
struct Max
{
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
const Y x0_converted = type_convert<Y>(x0);
const Y x1_converted = type_convert<Y>(x1);
y = ck::math::max(x0_converted, x1_converted);
}
};
struct Min
{
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
const Y x0_converted = type_convert<Y>(x0);
const Y x1_converted = type_convert<Y>(x1);
y = ck::math::min(x0_converted, x1_converted);
}
};
struct Multiply
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const
{
y = x0 * x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 * x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 * type_convert<half_t>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
{
y = type_convert<half_t>(x0 * x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) * x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 * x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 * x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x0 * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 * x1;
};
};
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// y = UnaryOp0(UnaryOp1(...(x)))
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// Execute first unary op to copy data to y
unary_ops_.At(Number<0>{})(y, x);
static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
};
Tuple<UnaryOpsSet...> unary_ops_;
};
// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
UnaryOp0 unary_op0,
UnaryOp1 unary_op1)
: binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
{
}
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
Y unary_x0_tmp_result;
Y unary_x1_tmp_result;
unary_op0_(unary_x0_tmp_result, x0);
unary_op1_(unary_x1_tmp_result, x1);
binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
};
private:
BinaryOp binary_op_;
UnaryOp0 unary_op0_;
UnaryOp1 unary_op1_;
};
// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
template <typename BinaryOp0,
typename BinaryOp1,
typename UnaryOp0,
typename UnaryOp1,
typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
BinaryOp0 binary_op1,
UnaryOp0 unary_op0,
UnaryOp1 unary_op1,
UnaryOp2 unary_op2)
: binary_op0_(binary_op0),
binary_op1_(binary_op1),
unary_op0_(unary_op0),
unary_op1_(unary_op1),
unary_op2_(unary_op2)
{
}
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
{
Y unary_x0_tmp_result;
Y unary_x1_tmp_result;
Y unary_x2_tmp_result;
unary_op0_(unary_x0_tmp_result, x0);
unary_op1_(unary_x1_tmp_result, x1);
unary_op2_(unary_x2_tmp_result, x2);
binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
};
private:
BinaryOp0 binary_op0_{};
BinaryOp1 binary_op1_{};
UnaryOp0 unary_op0_{};
UnaryOp1 unary_op1_{};
UnaryOp2 unary_op2_{};
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck

View File

@@ -12,10 +12,6 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
struct PassThroughPack2
{
template <typename Y, typename X>
@@ -449,11 +445,7 @@ struct FastGelu
const float u = x * (c1 * x * x + c2);
const float emu = __expf(u);
#if !CK_WORKAROUND_SWDEV_383542
y = x * __frcp_rn(1.f + emu);
#else
y = x * __ocml_native_recip_f32(1.f + emu);
#endif
y = x * ck::math::rcp(1.f + emu);
}
template <>
@@ -559,6 +551,244 @@ struct TanH
};
};
struct ACos
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::acos(x);
};
};
struct Neg
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::neg(x);
};
};
struct ATan
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::atan(x);
};
};
struct Sin
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::sin(x);
};
};
struct ASinH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::asinh(x);
};
};
struct Cos
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::cos(x);
};
};
struct ACosH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::acosh(x);
};
};
struct Tan
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tan(x);
};
};
struct ATanH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::atanh(x);
};
};
struct SinH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::sinh(x);
};
};
struct Ceil
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::ceil(x);
};
};
struct Exp
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::exp(x);
};
};
struct CosH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::cosh(x);
};
};
struct Floor
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::floor(x);
};
};
struct Log
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::log(x);
};
};
struct ASin
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::asin(x);
};
};
struct Rcp
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::rcp(x);
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}

View File

@@ -118,8 +118,16 @@ struct GridwiseElementwise
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
const index_t m1_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
const auto thread_grid_offset =
make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
const auto input_thread_grid_offset = generate_tuple(
[&](auto) {
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
},
Number<NumInput>{});
const auto output_thread_grid_offset = generate_tuple(
[&](auto) {
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
},
Number<NumOutput>{});
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// If src and dst have same vector dim, then:
@@ -157,9 +165,9 @@ struct GridwiseElementwise
uniform_sequence_gen_t<NumOutput, 1>,
uniform_sequence_gen_t<NumInput, false>,
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
thread_grid_offset,
input_thread_grid_offset,
out_grid_desc_tuple,
thread_grid_offset,
output_thread_grid_offset,
elementwise_op};
global_to_global_transfer.Run(
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -30,7 +30,7 @@ namespace ck {
// D0, D1, ... and E have the same layout
template <typename AsDataType,
typename BsDataType,
typename ComputeDataType_,
typename AComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -71,7 +71,8 @@ template <typename AsDataType,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType_ = AComputeDataType_>
struct GridwiseGemmMultipleABD_xdl_cshuffle
{
static constexpr index_t NumATensor = AsDataType::Size();
@@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using ComputeDataType =
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
using ComputeDataType = ComputeDataType_;
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ComputeDataType),
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
@@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<ComputeDataType>,
Tuple<AComputeDataType>,
decltype(as_grid_desc_ak0_m_ak1),
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
@@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<ComputeDataType>,
Tuple<BComputeDataType>,
decltype(bs_grid_desc_bk0_n_bk1),
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
@@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeDataType, // ComputeDataType for A
ComputeDataType, // ComputeDataType for B
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
@@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -73,7 +73,7 @@ template <typename ADataType,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType = AComputeDataType_>
typename BComputeDataType_ = AComputeDataType_>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
#if CK_WORKAROUND_DENORM_FIX
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()

View File

@@ -31,7 +31,8 @@ namespace ck {
// D0, D1, ... and E have the same layout
template <typename ADataType,
typename BDataType,
typename ComputeType,
typename AComputeType,
typename BComputeType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -71,7 +72,9 @@ template <typename ADataType,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer,
typename ALDSType,
typename BLDSType>
struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ComputeType),
return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
b_block_space_size_aligned * sizeof(BLDSType),
c_block_size * sizeof(CShuffleDataType));
}
@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumDTensor_,
typename DsDataType_,
bool Zeroing,
typename AGridDesc_KBatch_AK0_M_AK1,
typename BGridDesc_KBatch_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ComputeType,
ALDSType,
decltype(a_grid_desc_kbatch_ak0_m_ak1),
decltype(a_block_desc_kbatch_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
ComputeType,
BLDSType,
decltype(b_grid_desc_kbatch_bk0_n_bk1),
decltype(b_block_desc_kbatch_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
ALDSType,
BLDSType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
@@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
LoopSched,
AComputeType,
BComputeType>();
#if 1
if(block_work_idx[I0] == 0)
if constexpr(Zeroing)
{
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
const index_t numNThreads = NPerBlock / nThreadSize;
const index_t numMThreads = BlockSize / numNThreads;
const index_t mThreadSize = MPerBlock / numMThreads;
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
EDataType,
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
true>
e_thread_zero_buf;
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
EDataType,
EDataType,
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, mThreadSize, 1, nThreadSize>,
Sequence<0, 1, 2, 3>,
3,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1],
m_tid * mThreadSize,
block_work_idx[I2],
n_tid * nThreadSize),
ck::tensor_operation::element_wise::PassThrough{}};
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_zero_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
__syncthreads();
if(threadIdx.x == 0)
if(block_work_idx[I0] == 0)
{
atomicAdd(barrier_count_finished, 1);
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
const index_t numNThreads = NPerBlock / nThreadSize;
const index_t numMThreads = BlockSize / numNThreads;
const index_t mThreadSize = MPerBlock / numMThreads;
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
EDataType,
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
true>
e_thread_zero_buf;
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
EDataType,
EDataType,
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, mThreadSize, 1, nThreadSize>,
Sequence<0, 1, 2, 3>,
3,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1],
m_tid * mThreadSize,
block_work_idx[I2],
n_tid * nThreadSize),
ck::tensor_operation::element_wise::PassThrough{}};
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_zero_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
__builtin_amdgcn_s_barrier();
if(threadIdx.x == 0)
{
atomicAdd(barrier_count_finished, 1);
}
}
}
#endif
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
@@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// shuffle C and write out
{
if(threadIdx.x == 0)
if constexpr(Zeroing)
{
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
if(threadIdx.x == 0)
{
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
}
__builtin_amdgcn_s_barrier();
}
__syncthreads();
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
@@ -951,18 +960,131 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
}
});
if(threadIdx.x == 0)
if constexpr(Zeroing)
{
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
if(k_id_finished_t == KBatch)
if(threadIdx.x == 0)
{
*barrier_count_finished = 0;
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
if(k_id_finished_t == KBatch)
{
*barrier_count_finished = 0;
}
}
}
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap>
__device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
uint32_t* barrier_count_finished,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch,
const Block2ETileMap& block_2_etile_map)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
DsGridDesc_M_N ds_grid_desc_m_n;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
});
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
if(kbatch_id == KBatch - 1)
{
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
barrier_count_finished,
KBatch,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
else
{
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
barrier_count_finished,
KBatch,
a_element_op,
b_element_op,
ck::tensor_operation::element_wise::PassThrough{},
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
@@ -976,7 +1098,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
uint32_t* barrier_count_finished,
uint32_t*,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
@@ -1028,49 +1150,22 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
if(kbatch_id == KBatch - 1)
{
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
barrier_count_finished,
KBatch,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
else
{
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
barrier_count_finished,
KBatch,
a_element_op,
b_element_op,
ck::tensor_operation::element_wise::PassThrough{},
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
nullptr,
KBatch,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,6 +14,10 @@
namespace ck {
namespace math {
#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); };
@@ -111,6 +115,276 @@ inline __host__ double tanh<double>(double x)
return std::tanh(x);
};
template <typename T>
inline __host__ T acos(T x)
{
return ck::type_convert<T>(std::acosf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float acos<float>(float x)
{
return std::acosf(x);
};
template <>
inline __host__ double acos<double>(double x)
{
return std::acos(x);
};
template <typename T>
inline __host__ T neg(T x)
{
return ck::type_convert<T>(-(ck::type_convert<float>(x)));
};
template <>
inline __host__ float neg<float>(float x)
{
return -x;
};
template <>
inline __host__ double neg<double>(double x)
{
return -x;
};
template <>
inline __host__ int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
inline __host__ int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <typename T>
inline __host__ T atan(T x)
{
return ck::type_convert<T>(std::atanf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float atan<float>(float x)
{
return std::atanf(x);
};
template <>
inline __host__ double atan<double>(double x)
{
return std::atan(x);
};
template <typename T>
inline __host__ T sin(T x)
{
return ck::type_convert<T>(std::sinf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float sin<float>(float x)
{
return std::sinf(x);
};
template <>
inline __host__ double sin<double>(double x)
{
return std::sin(x);
};
template <typename T>
inline __host__ T asin(T x)
{
return ck::type_convert<T>(std::asinf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float asin<float>(float x)
{
return std::asinf(x);
};
template <>
inline __host__ double asin<double>(double x)
{
return std::asin(x);
};
template <typename T>
inline __host__ T asinh(T x)
{
return ck::type_convert<T>(std::asinhf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float asinh<float>(float x)
{
return std::asinhf(x);
};
template <>
inline __host__ double asinh<double>(double x)
{
return std::asinh(x);
};
template <typename T>
inline __host__ T cos(T x)
{
return ck::type_convert<T>(std::cosf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float cos<float>(float x)
{
return std::cosf(x);
};
template <>
inline __host__ double cos<double>(double x)
{
return std::cos(x);
};
template <typename T>
inline __host__ T acosh(T x)
{
return ck::type_convert<T>(std::acoshf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float acosh<float>(float x)
{
return std::acoshf(x);
};
template <>
inline __host__ double acosh<double>(double x)
{
return std::acosh(x);
};
template <typename T>
inline __host__ T tan(T x)
{
return ck::type_convert<T>(std::tanf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float tan<float>(float x)
{
return std::tanf(x);
};
template <>
inline __host__ double tan<double>(double x)
{
return std::tan(x);
};
template <typename T>
inline __host__ T atanh(T x)
{
return ck::type_convert<T>(std::atanhf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float atanh<float>(float x)
{
return std::atanhf(x);
};
template <>
inline __host__ double atanh<double>(double x)
{
return std::atanh(x);
};
template <typename T>
inline __host__ T sinh(T x)
{
return ck::type_convert<T>(std::sinhf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float sinh<float>(float x)
{
return std::sinhf(x);
};
template <>
inline __host__ double sinh<double>(double x)
{
return std::sinh(x);
};
template <typename T>
inline __host__ T ceil(T x)
{
return ck::type_convert<T>(std::ceilf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float ceil<float>(float x)
{
return std::ceilf(x);
};
template <>
inline __host__ double ceil<double>(double x)
{
return std::ceil(x);
};
template <typename T>
inline __host__ T cosh(T x)
{
return ck::type_convert<T>(std::coshf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float cosh<float>(float x)
{
return std::coshf(x);
};
template <>
inline __host__ double cosh<double>(double x)
{
return std::cosh(x);
};
template <typename T>
inline __host__ T floor(T x)
{
return ck::type_convert<T>(std::floorf(ck::type_convert<float>(x)));
};
template <>
inline __host__ float floor<float>(float x)
{
return std::floorf(x);
};
template <>
inline __host__ double floor<double>(double x)
{
return std::floor(x);
};
template <typename T>
inline __host__ T rcp(T x)
{
return ck::type_convert<T>(1.f / ck::type_convert<float>(x));
};
template <typename T>
inline __host__ T exp(T x)
{
@@ -282,6 +556,286 @@ inline __device__ double tanh<double>(double x)
return ::tanh(x);
};
template <typename T>
inline __device__ T acos(T x)
{
return ck::type_convert<T>(::acosf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float acos<float>(float x)
{
return ::acosf(x);
};
template <>
inline __device__ double acos<double>(double x)
{
return ::acos(x);
};
template <typename T>
inline __device__ T neg(T x)
{
return ck::type_convert<T>(-(ck::type_convert<float>(x)));
};
template <>
inline __device__ float neg<float>(float x)
{
return -x;
};
template <>
inline __device__ double neg<double>(double x)
{
return -x;
};
template <>
inline __device__ int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
inline __device__ int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <>
inline __device__ half_t neg<half_t>(half_t x)
{
return __hneg(x);
};
template <typename T>
inline __device__ T atan(T x)
{
return ck::type_convert<T>(::atanf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float atan<float>(float x)
{
return ::atanf(x);
};
template <>
inline __device__ double atan<double>(double x)
{
return ::atan(x);
};
template <typename T>
inline __device__ T sin(T x)
{
return ck::type_convert<T>(::sinf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float sin<float>(float x)
{
return ::sinf(x);
};
template <>
inline __device__ double sin<double>(double x)
{
return ::sin(x);
};
template <>
inline __device__ half_t sin<half_t>(half_t x)
{
return ::hsin(x);
};
template <typename T>
inline __device__ T asin(T x)
{
return ck::type_convert<T>(::asinf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float asin<float>(float x)
{
return ::asinf(x);
};
template <>
inline __device__ double asin<double>(double x)
{
return ::asin(x);
};
template <typename T>
inline __device__ T asinh(T x)
{
return ck::type_convert<T>(::asinhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float asinh<float>(float x)
{
return ::asinhf(x);
};
template <>
inline __device__ double asinh<double>(double x)
{
return ::asinh(x);
};
template <typename T>
inline __device__ T acosh(T x)
{
return ck::type_convert<T>(::acoshf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float acosh<float>(float x)
{
return ::acoshf(x);
};
template <>
inline __device__ double acosh<double>(double x)
{
return ::acosh(x);
};
template <typename T>
inline __device__ T tan(T x)
{
return ck::type_convert<T>(::tanf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float tan<float>(float x)
{
return ::tanf(x);
};
template <>
inline __device__ double tan<double>(double x)
{
return ::tan(x);
};
template <typename T>
inline __device__ T atanh(T x)
{
return ck::type_convert<T>(::atanhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float atanh<float>(float x)
{
return ::atanhf(x);
};
template <>
inline __device__ double atanh<double>(double x)
{
return ::atanh(x);
};
template <typename T>
inline __device__ T sinh(T x)
{
return ck::type_convert<T>(::sinhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float sinh<float>(float x)
{
return ::sinhf(x);
};
template <>
inline __device__ double sinh<double>(double x)
{
return ::sinh(x);
};
template <typename T>
inline __device__ T ceil(T x)
{
return ck::type_convert<T>(::ceilf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float ceil<float>(float x)
{
return ::ceilf(x);
};
template <>
inline __device__ double ceil<double>(double x)
{
return ::ceil(x);
};
template <>
inline __device__ half_t ceil<half_t>(half_t x)
{
return ::hceil(x);
};
template <typename T>
inline __device__ T cosh(T x)
{
return ck::type_convert<T>(::coshf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float cosh<float>(float x)
{
return ::coshf(x);
};
template <>
inline __device__ double cosh<double>(double x)
{
return ::cosh(x);
};
template <typename T>
inline __device__ T floor(T x)
{
return ck::type_convert<T>(::floorf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float floor<float>(float x)
{
return ::floorf(x);
};
template <>
inline __device__ double floor<double>(double x)
{
return ::floor(x);
};
template <>
inline __device__ half_t floor<half_t>(half_t x)
{
return ::hfloor(x);
};
template <typename T>
inline __device__ T rcp(T x)
{
#if !CK_WORKAROUND_SWDEV_383542
return __frcp_rn(x);
#else
return __ocml_native_recip_f32(x);
#endif
};
template <typename T>
inline __device__ T exp(T x)
{

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"

View File

@@ -7,6 +7,20 @@
namespace ck_tile {
enum struct GenericAttentionMaskEnum
{
NO_MASK = 0,
// below enum could be causal, or sliding window
MASK_FROM_TOP_LEFT = 1,
MASK_FROM_BOTTOM_RIGHT = 2,
// this enum maybe not used by xformer/FA, since it's hard to
// specify left/right window for varlen case. put it here for
// debug purpose
MASK_GENERIC,
};
// clang-format off
/* generic Attention Mask Coordinate
use x(horizontal axis), y(vertical axis) to describe mask.
@@ -188,6 +202,129 @@ struct GenericAttentionMask
index_t y_total, x_total;
};
// clang-format off
namespace impl {
template <bool IsMasking_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
}
// clang-format on
// this version only have 2 variation: masking and non-masking
// This is more friendly to codegen (e.g. need generate less kernel)
// ... with the trade-off that may have more instruction in causal mode
template <bool IsMasking_ = true>
struct SimplifiedGenericAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
{
}
// to get the loop length along X axis, return index:[start, end), end-start=length
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
// TODO: x_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
return ck_tile::make_tuple(x_start, x_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
return i_x >= x_total;
}
else
{
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = i_y + x; // this could be larger than x_total, but it's fine
return i_x < x_start || i_x >= x_end;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
// TODO: no need to check begin
return (i_x + TileWidth) > x_total;
}
else
{
// check top-right corner > x or left-borrom corner < x
index_t i_x_end = i_x + TileWidth;
index_t i_y_end = i_y + TileHeight;
// index_t x_end = min(i_y + x, x_total);
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > (i_x + y);
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge;
}
}
private:
index_t y, x;
index_t y_total, x_total;
};
// TODO: prefer use this function in host code
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
@@ -199,29 +336,32 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t x_total,
bool is_top_left = true)
{
index_t x = 0, y = 0;
// TODO: below should all use sgpr arithmetic
index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
if(is_top_left)
{
if(left_size < 0)
left_size = y_total - 1;
if(right_size < 0)
right_size = x_total - 1;
left_size = left_size < 0 ? left_size_tmp : left_size;
right_size = right_size < 0 ? right_size_tmp : right_size;
x = 1 + right_size;
y = left_size + 1;
}
else
{
if(left_size < 0)
left_size = x_total - 1;
if(right_size < 0)
right_size = y_total - 1;
index_t x_tmp = is_top_left ? 0 : x_total - y_total;
index_t y_tmp = is_top_left ? 0 : y_total - x_total;
x = x_total - y_total + 1 + right_size;
y = y_total - x_total + 1 + left_size;
}
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
}
} // namespace ck_tile

View File

@@ -157,7 +157,9 @@ struct FmhaFwdKernel
struct FmhaFwdMaskKargs
{
ck_tile::index_t mask_y, mask_x;
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdCommonLSEKargs
@@ -227,8 +229,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
// QElementFunction q_element_func,
// KElementFunction k_element_func,
// VElementFunction v_element_func,
@@ -285,8 +288,9 @@ struct FmhaFwdKernel
}
if constexpr(kHasMask)
{
kargs.mask_y = mask_y;
kargs.mask_x = mask_x;
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
{
@@ -324,8 +328,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
// QElementFunction q_element_func,
// KElementFunction k_element_func,
// VElementFunction v_element_func,
@@ -380,8 +385,9 @@ struct FmhaFwdKernel
}
if constexpr(kHasMask)
{
kargs.mask_y = mask_y;
kargs.mask_x = mask_x;
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
{
@@ -665,7 +671,12 @@ struct FmhaFwdKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k};
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaPipelineEnum
{
QRKSVS = 0,
QRKSVS_ASYNC,
QRKSVS_FP8,
QSKSVS,
};
} // namespace ck_tile

View File

@@ -0,0 +1,110 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <index_t NumATensors, typename ADataType, typename BDataType, typename ElementOp>
struct ReferenceElementwise : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
Tensor<BDataType>& b_tensor,
ElementOp element_op)
: a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op}
{
}
const std::array<Tensor<ADataType>, NumATensors>& a_tensors_;
Tensor<BDataType>& b_tensor_;
ElementOp element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceElementwise::Argument;
float Run(const Argument& arg)
{
if constexpr(NumATensors == 1)
{
arg.b_tensor_.ForEach([&](auto& self, auto idx) {
arg.element_op_(self(idx), arg.a_tensors_[0](idx));
});
}
else if constexpr(NumATensors == 2)
{
arg.b_tensor_.ForEach([&](auto& self, auto idx) {
arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx));
});
}
else if constexpr(NumATensors == 3)
{
arg.b_tensor_.ForEach([&](auto& self, auto idx) {
arg.element_op_(self(idx),
arg.a_tensors_[0](idx),
arg.a_tensors_[1](idx),
arg.a_tensors_[2](idx));
});
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
Tensor<BDataType>& b_tensor,
ElementOp element_op)
{
return Argument{a_tensors, b_tensor, element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceElementwise"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,175 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// assumption: every D matrix has the same layout and the same datatype
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemmMultipleD : public device::BaseOperator
{
using DDataType = remove_cvref_t<tuple_element_t<0, DsDataType>>;
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
ds_m_n_{ds_m_n},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceGemmMultipleD::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
ComputeTypeA v_a = 0;
ComputeTypeB v_b = 0;
for(int k = 0; k < K; ++k)
{
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c = 0;
if constexpr(DsDataType::Size() == 0)
{
arg.cde_element_op_(v_c, v_acc);
}
else if constexpr(DsDataType::Size() == 1)
{
arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n));
}
else if constexpr(DsDataType::Size() == 2)
{
arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n));
}
arg.c_m_n_(m, n) = v_c;
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_m_k, b_k_n, ds_m_n, c_m_n, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemmMultipleD"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck

View File

@@ -12,398 +12,21 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef DL_KERNELS
#include "gemm_dl.inc"
#endif
#ifdef CK_USE_WMMA
#include "gemm_wmma.inc"
#endif
#ifdef CK_USE_XDL
#include "gemm_xdl.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS)
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS)
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS)
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
template <typename ALayout,
typename BLayout,
typename CLayout,
@@ -435,16 +58,137 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef DL_KERNELS
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
}
}
#endif
#endif // DL_KERNELS
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
#endif
#ifdef CK_USE_XDL
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
op_ptrs);
@@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
op_ptrs);
@@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
op_ptrs);
@@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
op_ptrs);
@@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
op_ptrs);
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
/// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
@@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory<
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
#endif
}
}
#endif
@@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
#endif
return op_ptrs;
}

View File

@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
@@ -69,7 +69,7 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
@@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#endif
#ifdef CK_ENABLE_INT8
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
{

View File

@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_ENABLE_FP16)
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(CK_ENABLE_FP32)
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(CK_ENABLE_INT8)
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,238 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_INT8
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -10,439 +10,18 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_USE_XDL
#include "grouped_convolution_backward_data_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_backward_data_wmma.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
@@ -488,9 +67,10 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
@@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory<
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
@@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory<
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
@@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory<
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
#endif
#ifdef CK_USE_WMMA
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
op_ptrs);
@@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
op_ptrs);
@@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory<
#endif
}
}
#endif
return op_ptrs;
}

View File

@@ -0,0 +1,243 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,216 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

Some files were not shown because too many files have changed in this diff Show More