Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-12 01:10:17 +00:00)

Merge branch 'ck_tile/refactor' into ck_tile/elementwise
@@ -145,6 +145,22 @@ if(GPU_TARGETS)
 else()
     message("Building CK for the following targets: ${AMDGPU_TARGETS}")
 endif()
+
+if (GPU_TARGETS)
+    if (GPU_TARGETS MATCHES "gfx9")
+        add_definitions(-DCK_USE_XDL)
+        set(CK_USE_XDL "ON")
+    endif()
+    if (GPU_TARGETS MATCHES "gfx11")
+        add_definitions(-DCK_USE_WMMA)
+        set(CK_USE_WMMA "ON")
+    endif()
+else()
+    add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
+    set(CK_USE_XDL "ON")
+    set(CK_USE_WMMA "ON")
+endif()
+
 find_package(hip)
 # No assumption that HIP kernels are launched with uniform block size for backward compatibility
 # SWDEV-413293 and https://reviews.llvm.org/D155213
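For context, a minimal C++ sketch of how compile definitions such as -DCK_USE_XDL / -DCK_USE_WMMA are typically consumed downstream; the dispatch function below is hypothetical and not part of this commit:

// Hypothetical illustration only: preprocessor gating of
// instruction-set-specific kernel paths selected at configure time.
#include <cstdio>

void launch_gemm()
{
#if defined(CK_USE_XDL)
    std::puts("dispatching XDL (gfx9 matrix-core) kernel path");
#elif defined(CK_USE_WMMA)
    std::puts("dispatching WMMA (gfx11) kernel path");
#else
    std::puts("no matrix-core backend enabled at configure time");
#endif
}

int main() { launch_gemm(); }

Gating at configure time keeps WMMA-only code out of gfx9 builds and XDL-only code out of gfx11 builds, which is the point of the CMake block above.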
@@ -1,27 +1,29 @@
-add_custom_target(client_gemm_fastgelu_examples)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_custom_target(client_gemm_fastgelu_examples)
 
-add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
-target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
+    target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
 
-add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
-target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
+    target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
 
-add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
-target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
+    target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
 
-add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
+    add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
                      client_gemm_fastgelu)
 
-add_custom_target(client_gemm_fastgelu_generic_examples)
+    add_custom_target(client_gemm_fastgelu_generic_examples)
 
-add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
-target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
+    target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
 
-add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
-target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
+    target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
 
-add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
-target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
+    target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
 
-add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
+    add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
                      client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic)
+endif()
@@ -1,5 +1,7 @@
-add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
-target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
+    target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
 
-add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
-target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
+    add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
+    target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
+endif()
@@ -1,15 +1,16 @@
-add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
-target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
+    target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
-target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
+    add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
+    target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
-target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
+    add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
+    target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
-target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
-
-add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
-target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
+    add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
+    target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
+
+    add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
+    target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
+endif()
@@ -1,5 +1,7 @@
-add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
-target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
+    target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
 
-add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
-target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
+    add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
+    target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
+endif()
@@ -1,5 +1,7 @@
-add_executable(client_fused_attention fused_attention.cpp)
-target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_fused_attention fused_attention.cpp)
+    target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_fused_attention_bias fused_attention_bias.cpp)
-target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_fused_attention_bias fused_attention_bias.cpp)
+    target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+endif()
@@ -1,22 +1,22 @@
-if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
-add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
-target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))
+    add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
+    target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
-target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
+    target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
-target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
+    target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
-target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
+    target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
-target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
+    target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
-target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
+    target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 
-add_executable(client_gemm_quantization gemm_quantization.cpp)
-target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
+    add_executable(client_gemm_quantization gemm_quantization.cpp)
+    target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
 endif()
@@ -1,5 +1,7 @@
-add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
-add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
+    add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
 
-target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations)
-target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations)
+    target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations)
+    target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations)
+endif()
@@ -1,3 +1,4 @@
-add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
-target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
-
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
+    target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
+endif()
@@ -17,6 +17,11 @@ if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
     target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
 endif()
 
+if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
+    add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
+    target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
+endif()
+
 if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
     add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
     target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations)
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <cstdlib>
 #include <iomanip>
@@ -95,7 +95,8 @@ template <ck::index_t NumDimSpatial,
           typename WeiLayout,
           typename OutLayout,
           ck::index_t NumNonSpatialDim = 3,
-          typename ComputeType         = InDataType>
+          typename AComputeType        = InDataType,
+          typename BComputeType        = AComputeType>
 bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
                           std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
                           std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
@@ -186,7 +187,8 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
                           PassThrough,
                           PassThrough,
                           PassThrough,
-                          ComputeType>;
+                          AComputeType,
+                          BComputeType>;
     // get device op instances
     const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
         DeviceOp>::GetInstances();
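The refactor above replaces a single ComputeType template parameter with independent A- and B-side compute types. A stripped-down sketch (hypothetical names, not CK's actual class) of the defaulting pattern and why it is backward compatible:

// Hypothetical reduction of the pattern: BComputeType defaults to
// AComputeType, which itself defaults to the input data type, so existing
// single-compute-type callers keep compiling unchanged.
#include <type_traits>

template <typename InDataType,
          typename AComputeType = InDataType,
          typename BComputeType = AComputeType>
struct ConvKernelConfig
{
    using A = AComputeType;
    using B = BComputeType;
};

// Old-style instantiation still works (A == B == input type) ...
static_assert(std::is_same_v<ConvKernelConfig<float>::A, float>);
static_assert(std::is_same_v<ConvKernelConfig<float>::B, float>);

// ... while mixed-precision callers (e.g. fp8 A with bf8 B, as in the new
// example below) can now let the two sides diverge.
struct f8_t {};
struct bf8_t {};
static_assert(std::is_same_v<ConvKernelConfig<f8_t, f8_t, bf8_t>::B, bf8_t>);

int main() {}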
client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp (new file, 50 lines)
@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = ck::f8_t;
using WeiDataType = ck::bf8_t;
using OutDataType = ck::f8_t;

using InLayout  = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;

using AComputeType = ck::f8_t;
using BComputeType = ck::bf8_t;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G             = 1;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 64;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 3;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 3;

int main()
{
    return run_grouped_conv_fwd<NumDimSpatial,
                                InDataType,
                                WeiDataType,
                                OutDataType,
                                InLayout,
                                WeiLayout,
                                OutLayout,
                                3,
                                AComputeType,
                                BComputeType>(
               {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
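A quick standalone sanity check of the problem sizes above (this snippet is not part of the commit, and it assumes the example's default convolution parameters are stride 1 and padding 1, which is what makes the listed input and output extents consistent):

// Standard convolution output-length formula:
// out = (in + 2*pad - dilation*(kernel-1) - 1) / stride + 1
constexpr int conv_out_len(int in, int kernel, int stride, int pad, int dilation)
{
    return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

// Di=28, Z=3 -> Do=28 (same for Hi/Y -> Ho) and Wi=3, X=3 -> Wo=3,
// both under stride 1, pad 1.
static_assert(conv_out_len(28, 3, 1, 1, 1) == 28, "depth/height extent");
static_assert(conv_out_len(3, 3, 1, 1, 1) == 3, "width extent");

int main() { return 0; }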
@@ -1,2 +1,4 @@
-add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
-target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
+    target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
+endif()
@@ -1,4 +1,4 @@
-if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
+if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES))
     add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
     target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
 endif()
@@ -1,2 +1,4 @@
-add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
-target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
+    target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
+endif()
@@ -1,11 +1,13 @@
-add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
-target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
+    target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
 
-add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
-target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
+    target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
 
-add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
-target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
+    target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
 
-add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp)
-target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations)
+    add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp)
+    target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations)
+endif()
@@ -1,3 +1,4 @@
+if(GPU_TARGETS MATCHES "gfx9")
 # Fwd scaleadd scaleadd relu
 add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32
                grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp)
@@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke
 add_executable(client_grouped_convnd_bwd_data_scale_fp16
                grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp)
 target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations)
+endif()
@@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
 target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
 add_executable(client_wrapper_img2col wrapper_img2col.cpp)
 target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
-if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
-   GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
-   GPU_TARGETS MATCHES "gfx942")
+if(GPU_TARGETS MATCHES "gfx9")
     add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
     target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
     add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)
@@ -48,7 +48,10 @@ else()
     endif()
 endif()
 
-find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations)
+find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations)
+if(GPU_TARGETS MATCHES "gfx9")
+    find_package(composable_kernel COMPONENTS device_contraction_operations)
+endif()
 find_package(hip REQUIRED PATHS /opt/rocm)
 message(STATUS "Build with HIP ${hip_VERSION}")
 
@@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
 
 add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-if(GPU_TARGETS MATCHES "gfx11")
-    add_custom_target(example_gemm_wmma)
-    add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
-    add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
-endif()
 
 add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
@@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
 
-# FIXME: re-enable this example as test when SWDEV-335738 is fixed
-add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
+add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
 
 add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
@@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
 
+add_custom_target(example_gemm_wmma)
+add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
+add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
@@ -1,20 +1,3 @@
-list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
-        add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
-        add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
-    endif()
-    if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
-        set(target 1)
-    endif()
-endforeach()
-
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
-        add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
+add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
+add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
@@ -1,8 +1 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
@@ -1,29 +1,20 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_custom_target(example_gemm_add_add_fastgelu_xdl)
-        add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
-        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
+add_custom_target(example_gemm_add_add_fastgelu_xdl)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
 
-        add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
-        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
 
-        add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
-        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
 
-        if(USE_BITINT_EXTENSION_INT4)
-            add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
-            add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
-        endif(USE_BITINT_EXTENSION_INT4)
-        add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
-        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
-        set(target 1)
-    endif()
-endforeach()
 
 set(gpu_list "")
+if(USE_BITINT_EXTENSION_INT4)
+    add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
+    add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
+endif(USE_BITINT_EXTENSION_INT4)
 
 list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
 set(target 0)
@@ -1,19 +1,11 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
-        add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
-        add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
-        add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
-        add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
-        add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
-        # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
-        add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
-        set(target 1)
-    endif()
-endforeach()
-
+add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
+add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
+add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
+add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
 add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
 add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
 add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp (new file, 83 lines)
@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"

#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"

using InDataType       = ck::f8_t;
using WeiDataType      = ck::bf8_t;
using AccDataType      = float;
using CShuffleDataType = ck::f8_t;
using OutDataType      = ck::f8_t;
using AComputeType     = ck::f8_t;
using BComputeType     = ck::bf8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvSpec =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
        NDimSpatial,
        InLayout,
        WeiLayout,
        ck::Tuple<>,
        OutLayout,
        InDataType,
        WeiDataType,
        AccDataType,
        CShuffleDataType,
        ck::Tuple<>,
        OutDataType,
        InElementOp,
        WeiElementOp,
        OutElementOp,
        ConvSpec,    // ConvForwardSpecialization
        GemmSpec,    // GemmSpecialization
        1,           //
        256,         // BlockSize
        128,         // MPerBlock
        256,         // NPerBlock
        32,          // KPerBlock
        8,           // AK1
        8,           // BK1
        32,          // MPerXdl
        32,          // NPerXdl
        2,           // MXdlPerWave
        4,           // NXdlPerWave
        S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
        S<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,           // ABlockTransferSrcVectorDim
        8,           // ABlockTransferSrcScalarPerVector
        8,           // ABlockTransferDstScalarPerVector_AK1
        1,           // ABlockLdsExtraM
        S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
        S<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // BBlockTransferSrcAccessOrder
        2,           // BBlockTransferSrcVectorDim
        8,           // BBlockTransferSrcScalarPerVector
        8,           // BBlockTransferDstScalarPerVector_BK1
        1,           // BBlockLdsExtraN
        1,
        1,
        S<1, 32, 1, 8>,
        8,
        AComputeType,
        BComputeType>;

#include "run_convnd_fwd_example.inc"

int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
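A standalone consistency check of the tile parameters in the instance above (not CK code; it assumes the usual gfx9 wavefront size of 64 lanes):

int main()
{
    // Tile parameters taken from the DeviceGroupedConvNDFwdInstance above.
    constexpr int BlockSize = 256, MPerBlock = 128, NPerBlock = 256;
    constexpr int MPerXdl = 32, NPerXdl = 32, MXdlPerWave = 2, NXdlPerWave = 4;
    constexpr int WaveSize = 64; // gfx9 wavefront (assumption)

    // Waves per block in M and N must tile the block exactly, and the wave
    // grid must account for every thread in the block.
    constexpr int MWaves = MPerBlock / (MPerXdl * MXdlPerWave); // = 2
    constexpr int NWaves = NPerBlock / (NPerXdl * NXdlPerWave); // = 2
    static_assert(MWaves * MPerXdl * MXdlPerWave == MPerBlock);
    static_assert(NWaves * NPerXdl * NXdlPerWave == NPerBlock);
    static_assert(MWaves * NWaves * WaveSize == BlockSize);
    return 0;
}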
@@ -1,25 +1,17 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_custom_target(example_convnd_fwd_reduce_xdl)
+add_custom_target(example_convnd_fwd_reduce_xdl)
 
-        add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
-        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
+add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
+add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
 
-        add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
-        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
+add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
+add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
 
-        add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
+add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
+add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
 
-        add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
-        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
+add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
+add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
 
-        if(USE_BITINT_EXTENSION_INT4)
-            add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
-            add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
-        endif(USE_BITINT_EXTENSION_INT4)
-        set(target 1)
-    endif()
-endforeach()
+if(USE_BITINT_EXTENSION_INT4)
+    add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
+    add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
+endif(USE_BITINT_EXTENSION_INT4)
@@ -1,12 +1,3 @@
-# dlops
 add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
-# xdlops
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
-        add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
+add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
 add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
 add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
 
-add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
-add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
+add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
@@ -0,0 +1,394 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <array>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include <ck/utility/data_type.hpp>
#include <ck/utility/tuple.hpp>

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAdd      = ck::tensor_operation::element_wise::AddAdd;

using ADataType        = F16;
using BDataType        = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using DDataType        = F16;
using DsDataType       = ck::Tuple<DDataType, DDataType>;
using EDataType        = F32;

using ALayout  = Row;
using BLayout  = Col;
using DLayout  = Row;
using DsLayout = ck::Tuple<DLayout, DLayout>;
using ELayout  = Row;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = AddAdd;

static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr int NumDMatrices    = 2;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
    // clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on

struct ProblemSize final
{
    std::vector<ck::index_t> Ms;
    std::vector<ck::index_t> Ns;
    std::vector<ck::index_t> Ks;

    std::vector<ck::index_t> stride_As;
    std::vector<ck::index_t> stride_Bs;
    std::vector<std::vector<ck::index_t>> stride_Ds;
    std::vector<ck::index_t> stride_Cs;

    ck::index_t group_count;
};

struct ExecutionConfig final
{
    bool do_verification = true;
    int init_method      = 1;
    int k_batch          = 128;
    bool time_kernel     = true;
};

bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
    auto group_count = problem_size.group_count;

    // GEMM shape
    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
    std::vector<void*> p_Cs;
    std::vector<const void*> p_As;
    std::vector<const void*> p_Bs;
    std::vector<std::array<const void*, NumDMatrices>> p_Ds = {};

    gemm_descs.reserve(group_count);
    p_As.reserve(group_count);
    p_Bs.reserve(group_count);
    p_Ds.reserve(group_count);

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<BDataType>> b_tensors;
    std::vector<std::array<Tensor<DDataType>, NumDMatrices>> d_tensors;
    std::vector<Tensor<EDataType>> c_host_tensors;
    std::vector<Tensor<EDataType>> c_device_result_tensors;

    a_tensors.reserve(group_count);
    b_tensors.reserve(group_count);
    d_tensors.reserve(group_count);
    c_host_tensors.reserve(group_count);
    c_device_result_tensors.reserve(group_count);

    using DeviceMemPtr = std::unique_ptr<DeviceMem>;

    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
    std::vector<std::vector<DeviceMemPtr>> d_tensors_device;

    a_tensors_device.reserve(group_count);
    b_tensors_device.reserve(group_count);
    d_tensors_device.resize(group_count); // resized (not reserved): indexed per group below
    c_tensors_device.reserve(group_count);

    std::size_t flop = 0, num_btype = 0;

    for(int i = 0; i < group_count; i++)
    {
        a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
            problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
        b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
            problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));

        auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
        auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));

        std::array<Tensor<DDataType>, NumDMatrices> d_tens = {d0_tensor, d1_tensor};
        d_tensors.push_back(d_tens);
        c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
        c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
        std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                  << " b_k_n: " << b_tensors[i].mDesc
                  << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;

        flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
        num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
                     sizeof(BDataType) * b_tensors[i].GetElementSize() +
                     sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices +
                     sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();

        switch(config.init_method)
        {
        case 0: break;
        case 1:
            a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
            b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
            for(int j = 0; j < NumDMatrices; ++j)
            {
                d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
            }
            break;
        case 2:
            a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
            b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
            for(int j = 0; j < NumDMatrices; ++j)
            {
                d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
            }
            break;
        default:
            a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
            b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
            for(int j = 0; j < NumDMatrices; ++j)
            {
                d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
            }
        }
    }

    for(int i = 0; i < group_count; i++)
    {
        a_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));

        b_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));

        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));

        for(int j = 0; j < NumDMatrices; ++j)
        {
            d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
                d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
        }

        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
        for(int j = 0; j < NumDMatrices; ++j)
        {
            d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
        }
        c_tensors_device[i]->SetZero();
        p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
        p_Ds.push_back(
            {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
        p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
        gemm_descs.push_back({problem_size.Ms[i],
                              problem_size.Ns[i],
                              problem_size.Ks[i],
                              problem_size.stride_As[i],
                              problem_size.stride_Bs[i],
                              problem_size.stride_Cs[i],
                              problem_size.stride_Ds[i]});
    }
    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    // do GEMM
    auto argument = gemm.MakeArgument(
        p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
    gemm.SetKBatchSize(argument, config.k_batch);
    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }
    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
    gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());

    DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());

    invoker.Run(argument, StreamConfig{nullptr, false, 1});

    if(config.time_kernel)
    {
        float ave_time   = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }

    bool pass = true;
    if(config.do_verification)
    {
        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
                                                               BDataType,
                                                               DsDataType,
                                                               EDataType,
                                                               AccDataType,
                                                               AElementOp,
                                                               BElementOp,
                                                               CDEElementOp>;

        for(std::size_t i = 0; i < gemm_descs.size(); i++)
        {
            auto karg = argument.gemm_kernel_args_[i].karg_;
            auto dev_res_tensor =
                Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{}));

            c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
                                            c_device_result_tensors[i].mDesc.GetElementSize() *
                                                sizeof(EDataType));
            auto ref_gemm    = ReferenceGemmInstance{};
            auto ref_invoker = ref_gemm.MakeInvoker();

            auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
                                                      b_tensors[i],
                                                      d_tensors[i],
                                                      c_host_tensors[i],
                                                      a_element_op,
                                                      b_element_op,
                                                      cde_element_op);

            ref_invoker.Run(ref_argument);
            pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
        }

        std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
    }

    return pass;
}

std::vector<int> argToIntArray(char* input)
{
    std::vector<int> out;

    std::istringstream in(input);

    std::string item;

    while(std::getline(in, item, ','))
    {
        out.push_back(std::stoi(item));
    }

    return out;
}

int main(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    if(argc < 11)
    {
        std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
        problem_size.group_count = Ms.size();

        for(int i = 0; i < problem_size.group_count; i++)
        {
            problem_size.Ms.push_back(Ms[i]);
            problem_size.Ns.push_back(252);
            problem_size.Ks.push_back(4608);

            problem_size.stride_As.push_back(problem_size.Ks[i]);
            problem_size.stride_Bs.push_back(problem_size.Ks[i]);
            problem_size.stride_Cs.push_back(problem_size.Ns[i]);

            problem_size.stride_Ds.push_back({});
            for(int j = 0; j < NumDMatrices; ++j)
            {
                problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
            }
        }

        std::cout
            << "Usage:\n"
            << "arg1: verification (0=no, 1=yes)\n"
            << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
            << "arg3: time kernel (0=no, 1=yes)\n"
            << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
               "64,64 64,64 128,128)\n"
            << "arg10: k_batch (> 0)\n"
            << "... setting default values." << std::endl;
    }
    else
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        config.k_batch         = std::stoi(argv[10]);

        problem_size.Ms = argToIntArray(argv[4]);
        problem_size.Ns = argToIntArray(argv[5]);
        problem_size.Ks = argToIntArray(argv[6]);

        problem_size.stride_As = argToIntArray(argv[7]);
        problem_size.stride_Bs = argToIntArray(argv[8]);
        problem_size.stride_Cs = argToIntArray(argv[9]);

        problem_size.group_count = problem_size.Ms.size();

        for(int i = 0; i < problem_size.group_count; ++i)
        {
            // one stride entry per D matrix for each group (matches the default-path layout)
            problem_size.stride_Ds.push_back(
                std::vector<ck::index_t>(NumDMatrices, problem_size.stride_Cs[i]));
        }
    }

    return !run_grouped_gemm(problem_size, config);
}
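The two-stage split-K instance above partitions the reduction (K) dimension into k_batch slices and reduces the partial products afterwards. A conceptual standalone sketch of that decomposition (not CK internals), illustrated on a single dot product:

// Conceptual sketch: split-K computes k_batch partial reductions over
// disjoint K-ranges, then sums the partials; the result must equal the
// single-pass reduction.
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const int K = 16, k_batch = 4;
    std::vector<int> a(K), b(K);
    std::iota(a.begin(), a.end(), 0);
    std::iota(b.begin(), b.end(), 1);

    long long full = 0;
    for(int k = 0; k < K; ++k)
        full += static_cast<long long>(a[k]) * b[k];

    long long split             = 0;
    const int k_per_batch = K / k_batch;
    for(int s = 0; s < k_batch; ++s)                               // each slice = one partial
        for(int k = s * k_per_batch; k < (s + 1) * k_per_batch; ++k)
            split += static_cast<long long>(a[k]) * b[k];          // reduce partials

    std::printf("full=%lld split=%lld\n", full, split);
    return full == split ? 0 : 1;
}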
@@ -36,7 +36,7 @@ using BDataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using DsDataType       = ck::Tuple<>;
-using EDataType        = F32;
+using EDataType        = F16;
 
 using ALayout = Row;
 using BLayout = Col;
@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on

struct ProblemSize final
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])

for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(256);
problem_size.Ks.push_back(128);
problem_size.Ms.push_back(128 + rand() % 128);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(1024);

problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
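The new defaults draw each group's M from [128, 255] via 128 + rand() % 128. A sketch (not part of the commit) of seeding the generator so benchmark runs stay comparable; the fixed seed is an assumption:

#include <cstdlib>
#include <vector>

std::vector<int> make_group_Ms(int group_count, unsigned seed = 1)
{
    std::srand(seed);                 // fixed seed -> identical sizes on every run
    std::vector<int> Ms(group_count);
    for(int& m : Ms)
        m = 128 + std::rand() % 128;  // uniform-ish in [128, 255]
    return Ms;
}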
@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F8;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;

@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

struct ProblemSize final
@@ -1,48 +1,41 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
add_custom_target(example_gemm_reduce_xdl)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)

add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)

add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)

add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)

add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)

add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)

add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)

add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)

add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)

add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)

add_example_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)
add_example_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl)

if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
set(target 1)
endif()
endforeach()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
@@ -1,14 +1,7 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
endif()
set(target 1)
endif()
endforeach()
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
endif()

add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
if(result EQUAL 0)
@@ -1,29 +1,15 @@
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)

add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)

add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
set(target 1)
endif()
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)

if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
set(target 1)
endif()
endforeach()

add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)

add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)
@@ -1,12 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
set(target 1)
endif()
endforeach()

add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
@@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear)

# FP32
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)

add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)

add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)

add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)

add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)

add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)

# FP64
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)

add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)

add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)

add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)

# FP16
add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)

add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)

# BF16
add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp)
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)

add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp)
add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)

add_dependencies(example_contraction example_contraction_scale)
add_dependencies(example_contraction example_contraction_bilinear)
add_example_dependencies(example_contraction example_contraction_scale)
add_example_dependencies(example_contraction example_contraction_bilinear)
@@ -1,5 +1,2 @@
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)

if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
@@ -1,40 +1,23 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
add_custom_target(example_grouped_conv_fwd_multiple_d)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)

set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_custom_target(example_grouped_conv_fwd_multiple_d)
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)

add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)

add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)

add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)

add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() # USE_BITINT_EXTENSION_INT4

add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)

if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() # USE_BITINT_EXTENSION_INT4

set(target 1)
endif()
endforeach()

set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
@@ -1,17 +1,9 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)

set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)

if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
@@ -1,11 +1,9 @@
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
endif()
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)

add_custom_target(example_gemm_scale_softmax_gemm)
@@ -1,32 +1,23 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_splitK_gemm_xdl)
add_custom_target(example_splitK_gemm_xdl)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)

add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)

add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)

add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)

add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)

add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)

add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)

if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()

set(target 1)
endif()
endforeach()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()
@@ -1,27 +1,10 @@
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data)
add_custom_target(example_grouped_conv_bwd_data)

add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)

add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)

set(target 1)
endif()
endforeach()

foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data)

add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)

set(target 1)
endif()
endforeach()
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
@@ -1,24 +1,17 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)

# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
# Conv + bias + relu perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
# Conv + bias + relu perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
# Conv + bias + tanh perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
# Conv + bias + relu perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
# Conv + bias + relu perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
# Conv + bias + tanh perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
@@ -1,17 +1,9 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx908 gfx90a)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)

if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
@@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu
add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp)
add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp)
add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp)
add_example_executable(example_elementwise_permute elementwise_permute.cpp)
if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp (new file, 140 lines)
@@ -0,0 +1,140 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

using F16 = ck::half_t;
using F32 = float;

using ADataType = F16;
using BDataType = F16;

using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using BinaryAdd = ck::tensor_operation::element_wise::Add;
// B = alpha * A0 * A0 + beta * A1 * A1
using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise::
BinaryWithUnaryCombinedOp<BinaryAdd, UnaryScaleSquare, UnaryScaleSquare>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType, ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
BinaryAddUnaryScaleSquare, // ElementwiseOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8, 8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq

int main()
{
bool do_verification = true;
bool time_kernel = true;

std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};
ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides)};
Tensor<ADataType>& a0 = as[0];
Tensor<ADataType>& a1 = as[1];
Tensor<BDataType> b(ab_lengths, ab_strides);
float alpha = 3.f;
float beta = 2.f;
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a0_device_buf.ToDevice(a0.mData.data());
a1_device_buf.ToDevice(a1.mData.data());

std::array<const void*, 2> inputs = {a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths,
{ab_strides, ab_strides},
{ab_strides},
inputs,
output,
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters are not supported by the device instance, exiting!");
}

std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
std::cout << "B (nchw): " << b.mDesc << std::endl;

auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];

std::size_t num_btype = sizeof(ADataType) * 2 * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;

bool pass = true;

if(do_verification)
{
Tensor<BDataType> host_b(ab_lengths, ab_strides);

using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(
as,
host_b,
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

return pass ? 0 : 1;
}
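As the comment on BinaryAddUnaryScaleSquare states, the file above computes B = alpha * A0 * A0 + beta * A1 * A1 per element. A plain host sketch of the same arithmetic (not part of the commit; float evaluation is an assumption, the kernel operates on F16 data):

// Per-element reference: square each input, scale by alpha/beta, then add.
float reference_point(float a0, float a1, float alpha, float beta)
{
    float s0 = alpha * (a0 * a0); // UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}
    float s1 = beta * (a1 * a1);  // UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}
    return s0 + s1;               // BinaryAdd
}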
@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
{
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_ndhwc(n, d, h, w, c), a_val);
}
}

int main()
{
bool do_verification = true;
@@ -51,32 +39,7 @@ int main()

std::vector<std::size_t> ncdhw = {16, 8, 8, 8, 8};
std::vector<std::size_t> ndhwc = {16, 8, 8, 8, 8};
Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(ndhwc);

a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 5> ab_lengths;
/**std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[4]),
1};
std::array<ck::index_t, 5> b_strides = {
static_cast<int>(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]),
static_cast<int>(ndhwc[2] * ndhwc[3] * ndhwc[4]),
1,
static_cast<int>(ndhwc[3] * ndhwc[4]),
static_cast<int>(ndhwc[4])};**/

std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
@@ -93,6 +56,20 @@ int main()
1};
ck::ranges::copy(ncdhw, ab_lengths.begin());

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);

a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
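These permute examples drive the layout change purely through strides: the lengths are given in one index order and the input strides are permuted so that this order walks the source memory. A sketch of the packed-stride computation the hand-written arrays follow (helper name is hypothetical, not part of the commit):

#include <array>
#include <cstddef>

template <std::size_t N>
std::array<int, N> packed_strides(const std::array<int, N>& dims)
{
    std::array<int, N> s{};
    int running = 1;
    for(std::size_t i = N; i-- > 0;)
    {
        s[i] = running;       // innermost dimension is contiguous
        running *= dims[i];
    }
    return s;
}
// e.g. for dims {N, C, D, H, W} this yields {C*D*H*W, D*H*W, H*W, W, 1};
// the examples then reorder those entries to match their chosen length order.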
@@ -126,10 +103,16 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ndhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<4>, // InScalarPerVectorSeq
ck::Sequence<4>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
{
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_ndhwc(n, d, h, w, c), a_val);
}
}

int main()
{
bool do_verification = true;
@@ -59,10 +47,13 @@ int main()
const int W = 5;
const int D = 16;

std::vector<std::size_t> ncdhw = {N, C, D, H, W};
std::vector<std::size_t> ndhwc = {N, D, H, W, C};
Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(ndhwc);
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);

a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

@@ -74,10 +65,6 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -94,11 +81,12 @@ int main()
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer()
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] *
ab_lengths[3] * ab_lengths[4];

std::size_t num_btype =
sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
(sizeof(ADataType) + sizeof(BDataType)) *
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]);

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

@@ -111,10 +99,17 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ndhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);

using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}

int main()
{
bool do_verification = true;
@@ -55,18 +44,6 @@ int main()

std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);

a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -77,9 +54,22 @@ int main()
1,
static_cast<int>(nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[3])};

ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);

a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -111,10 +101,16 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
@@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
const std::vector<std::size_t>& shape_nchw,
Functor functor)
{
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}

int main()
{
bool do_verification = true;
@@ -54,13 +40,16 @@ int main()
const int N = 120;
const int C = 128;
const int H = 32;
const int W = 1024;
const int W = 32;

std::vector<std::size_t> nchw = {N, C, H, W};
std::vector<std::size_t> nhwc = {N, H, W, C};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};

Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);

a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -72,11 +61,6 @@ int main()
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
|
||||
|
||||
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
|
||||
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
|
||||
@@ -94,10 +78,11 @@ int main()
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
std::size_t flop =
|
||||
std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3];
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
|
||||
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) *
|
||||
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -110,11 +95,16 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance =
|
||||
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
|
||||
host_b, a, nchw, PassThrough{});
|
||||
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@
#include <random>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -21,11 +23,14 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;

using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}

int main()
{
bool do_verification = true;
@@ -60,8 +48,21 @@ int main()

std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};

std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 1.f;
auto i = 0;
std::mt19937 gen(11939);
@@ -84,22 +85,14 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 4> ab_lengths;

std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};

std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -113,11 +106,10 @@ int main()
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];

std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];

std::size_t num_btype =
(2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;
@@ -129,10 +121,17 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
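
UnaryScaleSquare composes two elementwise functors. Judging by how the examples construct it, UnaryCombinedOp<UnarySquare, UnaryScale>{UnarySquare{}, UnaryScale{scale}}, the component ops are assumed to apply in sequence, square first and then scale, i.e. y = scale * x * x. A self-contained sketch of that composition (illustrative only, not the CK implementation; the struct names merely mirror the diff):

// Minimal model of a combined unary op: op0 feeds op1.
#include <cassert>

struct UnarySquare
{
    template <typename T>
    void operator()(T& y, const T& x) const { y = x * x; }
};

struct UnaryScale
{
    float scale_;
    template <typename T>
    void operator()(T& y, const T& x) const { y = static_cast<T>(scale_) * x; }
};

template <typename Op0, typename Op1>
struct UnaryCombinedOp
{
    Op0 op0_;
    Op1 op1_;
    template <typename T>
    void operator()(T& y, const T& x) const
    {
        T tmp{};
        op0_(tmp, x); // first op: square
        op1_(y, tmp); // second op: scale
    }
};

int main()
{
    UnaryCombinedOp<UnarySquare, UnaryScale> op{UnarySquare{}, UnaryScale{2.f}};
    float y = 0.f;
    op(y, 3.f);
    assert(y == 18.f); // 2 * 3^2
    return 0;
}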
@@ -5,9 +5,11 @@
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;

using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}

int main()
{
bool do_verification = true;
@@ -55,18 +47,6 @@ int main()

std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -80,9 +60,29 @@ int main()

ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);

float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -112,10 +112,17 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
@@ -5,9 +5,11 @@
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;

using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}

int main()
{
bool do_verification = true;
bool time_kernel = true;

std::vector<std::size_t> nchw = {5, 4, 2, 3};
std::vector<std::size_t> nhwc = {5, 2, 3, 4};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
std::array<ck::index_t, 4> ab_lengths;

std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};

std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);

float scale = 1.f;
auto i = 0;
@@ -84,22 +86,14 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 4> ab_lengths;

std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};

std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -129,10 +123,17 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
@@ -5,9 +5,11 @@
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;

using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq

template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}

int main()
{
bool do_verification = true;
@@ -55,18 +47,6 @@ int main()

std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -80,9 +60,28 @@ int main()

ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a.mData.data());

std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -112,10 +111,17 @@ int main()

if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);

b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

using F16 = ck::half_t;
using F32 = float;

using ADataType = F16;
using BDataType = F16;

using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using BinaryAdd = ck::tensor_operation::element_wise::Add;
// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2
using TrinaryAddUnaryScaleSquare =
ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp<BinaryAdd,
BinaryAdd,
UnaryScaleSquare,
UnaryScaleSquare,
UnaryScaleSquare>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType, ADataType, ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
TrinaryAddUnaryScaleSquare, // ElementwiseOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq

int main()
{
bool do_verification = true;
bool time_kernel = true;

std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};

ck::ranges::copy(nchw, ab_lengths.begin());

std::array<Tensor<ADataType>, 3> as = {Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides)};
Tensor<ADataType>& a0 = as[0];
Tensor<ADataType>& a1 = as[1];
Tensor<ADataType>& a2 = as[2];
Tensor<BDataType> b(ab_lengths, ab_strides);
float alpha = 3.f;
float beta = 2.f;
float gamma = 4.f;
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a2.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());

a0_device_buf.ToDevice(a0.mData.data());
a1_device_buf.ToDevice(a1.mData.data());
a2_device_buf.ToDevice(a2.mData.data());

std::array<const void*, 3> inputs = {a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
a2_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};

auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths,
{ab_strides, ab_strides, ab_strides},
{ab_strides},
inputs,
output,
TrinaryAddUnaryScaleSquare{
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};

std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
std::cout << "A2 (nchw): " << a2.mDesc << std::endl;
std::cout << "B (nchw): " << b.mDesc << std::endl;

auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];

std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;

bool pass = true;

if(do_verification)
{
Tensor<BDataType> host_b(ab_lengths, ab_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();

auto ref_argument = ref_elementwise.MakeArgument(
as,
host_b,
TrinaryAddUnaryScaleSquare{
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
ref_invoker.Run(ref_argument);

const double threshold = std::pow(2, -10) * 2;
b_device_buf.FromDevice(b.mData.data());
pass &= ck::utils::check_err(
b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold);
}

return pass ? 0 : 1;
}
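
Per the comment in the new example, the trinary op computes B = alpha*A0*A0 + beta*A1*A1 + gamma*A2*A2 elementwise. A plain CPU sketch of that formula (illustrative only, not the CK reference implementation; shapes and values are made up):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t n = 16;
    const float alpha = 3.f, beta = 2.f, gamma = 4.f; // same scales as the example

    std::vector<float> a0(n, 0.5f), a1(n, 1.0f), a2(n, 2.0f), b(n);

    // B = alpha*A0^2 + beta*A1^2 + gamma*A2^2, applied per element.
    for(std::size_t i = 0; i < n; ++i)
        b[i] = alpha * a0[i] * a0[i] + beta * a1[i] * a1[i] + gamma * a2[i] * a2[i];

    // 3*0.25 + 2*1 + 4*4 = 18.75
    assert(b[0] == 18.75f);
    return 0;
}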
@@ -1,8 +1 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp)

@@ -1,15 +1,7 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_im2col_col2im)
add_custom_target(example_im2col_col2im)

add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)

add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)

set(target 1)
endif()
endforeach()
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)

@@ -1,8 +1 @@
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)

@@ -1,8 +1 @@
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)

@@ -2,16 +2,9 @@ add_subdirectory(binary)
add_subdirectory(multi_AB)
add_subdirectory(unary)

list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_activ_xdl)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
set(target 1)
endif()
endforeach()
add_custom_target(example_convnd_activ_xdl)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)

@@ -1,5 +1,3 @@
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
endif()
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
@@ -5,6 +5,12 @@ include_directories(BEFORE

add_custom_target(examples)

function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(FILE_NAME)
add_dependencies(EXAMPLE_NAME FILE_NAME)
endif()
endfunction(add_example_dependencies EXAMPLE_NAME)

function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
@@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -56,12 +56,13 @@ auto create_args(int argc, char* argv[])
.insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left local-attn with left right size\n"
"'b:l,r', bottom-r local-attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size\n")
.insert(
"mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left sliding window attn with left right size\n"
"'b:l,r', bottom-r sliding window attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
@@ -357,44 +358,45 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);

return fmha_fwd_args<FmhaDefaultElementFunctions>{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.y,
mask.x,
ck_tile::identity{},
ck_tile::identity{}};
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
ck_tile::identity{},
ck_tile::identity{}};
}();

float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
@@ -508,12 +510,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
}
else
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
if(lse)
{
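
The mask plumbing above moves from a raw (mask.y, mask.x) coordinate pair to (left, right) window sizes plus a mask type. A sketch of the assumed left/right window semantics for a top-left aligned mask: query row i may attend to key column j when i - left <= j <= i + right, with a negative window size meaning unbounded on that side (so left < 0, right == 0 degenerates to a plain causal mask). The exact rule inside ck_tile::make_generic_attention_mask_from_lr_window may differ in details; this is illustrative only:

#include <cassert>

// True if position (i, j) is masked out under a top-left aligned l/r window.
bool is_masked_out(int i, int j, int left, int right)
{
    const bool below_left  = (left >= 0) && (j < i - left);
    const bool above_right = (right >= 0) && (j > i + right);
    return below_left || above_right;
}

int main()
{
    // left = -1, right = 0: causal -- only j <= i survives.
    assert(!is_masked_out(3, 3, -1, 0));
    assert(!is_masked_out(3, 0, -1, 0));
    assert(is_masked_out(3, 4, -1, 0));

    // left = 2, right = 1: sliding window around the diagonal.
    assert(is_masked_out(5, 2, 2, 1));
    assert(!is_masked_out(5, 3, 2, 1));
    assert(!is_masked_out(5, 6, 2, 1));
    assert(is_masked_out(5, 7, 2, 1));
    return 0;
}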
@@ -104,169 +104,6 @@ struct FmhaMasks
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};

#if 0
// internal API, don't use this directly
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t nhead_k,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t max_seqlen_q,
float scale,
bool i_perm,
bool o_perm,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x)
{
constexpr bool is_v_rowmajor =
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;

assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias'
/// are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_k : nhead_k * seqlen_k;
}();
const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_k : seqlen_k;
}();
const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k);
const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1);
const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v);

auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
mask_y,
mask_x);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask_y,
mask_x);
}
}();

dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v);
return ck_tile::make_tuple(kargs, grids);
}

// This is the args from caller to underneath API, different from the kernel
struct fmha_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t batch;
ck_tile::index_t nhead;
ck_tile::index_t nhead_k;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t max_seqlen_q;
float scale;
bool i_perm;
bool o_perm;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
};
#endif

// runtime args, some will passed to karg, some will used to compute grids/blocks
template <typename ElementFunctions>
struct fmha_fwd_args
@@ -306,8 +143,9 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// typename ElementFunctions::QElementFunction q_element_func;
// typename ElementFunctions::KElementFunction k_element_func;
// typename ElementFunctions::VElementFunction v_element_func;
@@ -350,8 +188,9 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
@@ -392,8 +231,9 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.batch_stride_bias,
args.batch_stride_lse,
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
@@ -420,6 +260,7 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
bool kHasBias_,
bool kStoreLse_,
@@ -439,6 +280,7 @@ struct fmha_fwd_traits_
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
@@ -24,6 +24,16 @@ DTYPE_BITS = {
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
MASK_IMPL = {
|
||||
"generic" : "ck_tile::GenericAttentionMask",
|
||||
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
|
||||
}
|
||||
|
||||
MASK_SIMPLIFIED_MAP = {
|
||||
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
MASK_MAP = {
|
||||
"no" : "FmhaMasks::NoMask",
|
||||
"causal" : "FmhaMasks::CausalMask",
|
||||
@@ -49,13 +59,16 @@ ELEMENT_FUNC_MAP = {
|
||||
"no" : "FmhaDefaultElementFunctions",
|
||||
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
|
||||
}
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false"
|
||||
}
|
||||
|
||||
MASKS = ["no", "causal", "generic"]
|
||||
DIRECTIONS = ["fwd"]
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
|
||||
@@ -130,7 +143,8 @@ using fmha_kernel_{F_idx} =
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -168,17 +182,40 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
MASK_CHECK_MAP = {
|
||||
"no" : "t.mask_type == mask_enum::no_mask",
|
||||
"causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right",
|
||||
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic" : "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no" : "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask" : "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""

def get_mask_map(mask : str):
    if mask == "generic":
        return MASK_MAP
    elif mask == "simplified":
        return MASK_SIMPLIFIED_MAP
    else:
        assert False
        return None

def get_mask_check_map(mask : str):
    if mask == "generic":
        return MASK_CHECK_MAP
    elif mask == "simplified":
        return MASK_SIMPLIFIED_CHECK_MAP
    else:
        assert False
        return None

@dataclass
class FmhaFwdApiTrait:
    pipeline_tag : str
@@ -212,14 +249,19 @@ class FmhaFwdApiTrait:
            if self.spad == 't' : return 'true' # always support
            else                : return 'true'
        elif self.pipeline_tag in ['qr']:
            if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                : return f'a.seqlen_q % {self.bm0} == 0'
        else: assert False

    @property
    def skcheck(self) -> str:
        if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
        else                 : return f'a.seqlen_k % {self.bn0} == 0'
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
            else                 : return f'a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr', 'qr_fp8']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                 : return f'a.seqlen_k % {self.bn0} == 0'
        else: assert False

    @property
    def dcheck(self) -> str:
@@ -228,7 +270,7 @@ class FmhaFwdApiTrait:
            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
            else               : assert False
        elif self.pipeline_tag in ['qr']:
            if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else               : return f'a.hdim_q % {self.bk0blen} == 0'
        else: assert False

@@ -239,7 +281,7 @@ class FmhaFwdApiTrait:
            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
            else                : assert False
        elif self.pipeline_tag in ['qr']:
            if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                : return f'a.hdim_v % {self.bk0blen} == 0'
        else: assert False

@@ -270,13 +312,17 @@ class FmhaFwdPipeline:
        n = f'{self.tag}_v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        if self.F_bias == 't' : n += '_bias'
        if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
        if self.F_mask[0:2] == 's_':
            if self.F_mask == 's_mask': n += f'_mask'
        else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
        if self.F_lse == 't' : n += '_lse'
        return n

class FmhaFwdApiPool:
    def __init__(self):
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
@@ -298,8 +344,9 @@ class FmhaFwdApiPool:
            inners=str()
            for k, trait in enumerate(traits):
                if_k = 'if' if k == 0 else 'else if'
                inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask],
                        F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
                inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                        F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
                        F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                        F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                        F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
@@ -341,6 +388,7 @@ class FmhaFwdKernel:
    F_mode         : str  # value from MODE_MAP
    F_tile         : FmhaFwdTileSize
    F_pipeline     : FmhaFwdPipeline
    mask_impl      : str
    F_element_func : str

    @property
@@ -369,8 +417,9 @@ class FmhaFwdKernel:
            F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
            F_bias          = BOOL_MAP[self.F_pipeline.F_bias],
            F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
            F_occupancy     = self.F_tile.F_occupancy ,
            F_mask          = MASK_MAP[self.F_pipeline.F_mask],
            F_occupancy     = self.F_tile.F_occupancy,
            F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode          = MODE_MAP[self.F_mode],
            F_pipeline      = PIPELINE_MAP[self.F_pipeline.tag],
            F_element_func  = ELEMENT_FUNC_MAP[self.F_element_func])
@@ -426,14 +475,17 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
    else:
        return None

def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    # TODO: we don't support tuning yet, so pick one value for vlayout/pipeline/pad;
    #       support this in the future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function populates a list of possible pipelines
        # TODO: the order of the list matters! entries later in this list will also be checked later
        # TODO: currently for the qr pipeline, let 't' padding appear later!!
        # TODO: how to design this more generically?
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
            for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask))
@@ -446,16 +498,19 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask))
                    if receipt == 1:
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) # TODO: cover arbitrary hdim
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, mask)) # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
            # no need for lse kernels
            for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdApiPool()
    api_pool = FmhaFwdApiPool(mask_impl)

    for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
@@ -474,6 +529,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
                                F_mode=mode,
                                F_tile=tile,
                                F_pipeline=pipeline,
                                mask_impl=mask_impl,
                                F_element_func=element_func)
                if kernel_filter != None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
@@ -489,24 +545,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
    (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)

def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None:
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    if output_dir is None:
        output_dir = Path(__file__).parent
    else:
        output_dir = Path(output_dir) / GEN_DIR

    output_dir.mkdir(parents=True, exist_ok=True)
    api_pool, kernels = get_blobs(kernel_filter)
    api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl)
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
    write_api(api_pool, output_dir)

# list all the files that will be generated
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None:
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    assert output_file is not None
    file_path = Path(output_file)
    with file_path.open('a') as f:
        _, kernels = get_blobs(kernel_filter)
        _, kernels = get_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -535,8 +591,26 @@ if __name__ == "__main__":
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="codegen receipt. 0: generate only 8x hdim coverage\n" + \
             "                 1: generate more instances to cover all hdim"
    )

    args = parser.parse_args()
    if args.list_blobs is not None:
        list_blobs(args.list_blobs, args.filter)
        list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, args.filter)
        write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)

@@ -9,11 +9,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"

// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
    no_mask = 0,
    causal_top_left,
    causal_bottom_right,
    mask_top_left,
    mask_bottom_right,
    window_generic,
};

@@ -21,18 +22,19 @@ struct mask_info
{
    mask_enum type;
    ck_tile::index_t y, x;
    ck_tile::index_t left, right; // FA style SWA left/right

    void serialize(std::ostream& os) const
    {
        if(type == mask_enum::no_mask)
            os << "n";
        else if(type == mask_enum::causal_top_left)
            os << "tl";
        else if(type == mask_enum::causal_bottom_right)
            os << "br";
        else if(type == mask_enum::mask_top_left)
            os << "tl(" << left << ":" << right << ")";
        else if(type == mask_enum::mask_bottom_right)
            os << "br(" << left << ":" << right << ")";
        else
        {
            os << "g(" << y << "/" << x << ")";
            os << "g(" << y << ":" << x << ")";
        }
    }
    static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
@@ -57,22 +59,30 @@ struct mask_info
        // TODO: some validation
        if(t == "t")
        {
            auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
            tmp.type = mask_enum::mask_top_left;
            auto r   = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                v0, v1, y_total, x_total, true);
            tmp.y = r.at(ck_tile::number<0>{});
            tmp.x = r.at(ck_tile::number<1>{});
            tmp.y     = r.at(ck_tile::number<0>{});
            tmp.x     = r.at(ck_tile::number<1>{});
            tmp.left  = v0;
            tmp.right = v1;
        }
        else if(t == "b")
        {
            auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
            tmp.type = mask_enum::mask_bottom_right;
            auto r   = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                v0, v1, y_total, x_total, false);
            tmp.y = r.at(ck_tile::number<0>{});
            tmp.x = r.at(ck_tile::number<1>{});
            tmp.y     = r.at(ck_tile::number<0>{});
            tmp.x     = r.at(ck_tile::number<1>{});
            tmp.left  = v0;
            tmp.right = v1;
        }
        else if(t == "g")
        {
            tmp.y = v0;
            tmp.x = v1;
            tmp.y     = v0;
            tmp.x     = v1;
            tmp.left  = v0; // TODO: don't use this?
            tmp.right = v1;
        }
        else
        {
@@ -84,15 +94,19 @@ struct mask_info
    {
        // should be 0, 1, 2
        tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
        if(tmp.type == mask_enum::causal_top_left)
        if(tmp.type == mask_enum::mask_top_left)
        {
            tmp.y = seqlen_q;
            tmp.x = 1;
            tmp.y     = seqlen_q;
            tmp.x     = 1;
            tmp.left  = -1;
            tmp.right = 0;
        }
        else if(tmp.type == mask_enum::causal_bottom_right)
        else if(tmp.type == mask_enum::mask_bottom_right)
        {
            tmp.y = seqlen_q;
            tmp.x = seqlen_k - seqlen_q + 1;
            tmp.y     = seqlen_q;
            tmp.x     = seqlen_k - seqlen_q + 1;
            tmp.left  = -1;
            tmp.right = 0;
        }
    }
    return tmp;
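// Illustration (editorial sketch, not part of this change): decoding the
// command-line string "b:4,35" with seqlen_q=120 and seqlen_k=256 yields a
// bottom-right window mask:
//
//     mask_info m = mask_info::decode("b:4,35", 120, 256);
//     // m.type == mask_enum::mask_bottom_right, m.left == 4, m.right == 35;
//     // m.y/m.x come from make_generic_attention_mask_coordinates_from_lr_window(
//     //     4, 35, 120, 256, false)
//     m.serialize(std::cout); // prints "br(4:35)"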

@@ -23,7 +23,8 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS

done
done

@@ -45,6 +45,10 @@
#endif

// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
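// Editorial note: with these umbrella macros, guards that previously had to
// enumerate every CDNA target can collapse; a hypothetical example:
//
//     #if defined(__gfx94__)
//         // gfx940/941/942-only path
//     #elif defined(__gfx9__)
//         // path shared by gfx908/90a/94x
//     #endif
//
// which is exactly the simplification applied to the guards below.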
@@ -62,8 +66,7 @@
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx94__)
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
@@ -75,8 +78,7 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
    defined(__gfx94__) // for GPU code
#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
@@ -89,7 +91,7 @@
// MFMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#elif defined(__gfx9__) // for GPU code
#define CK_USE_AMD_MFMA
#endif

@@ -120,7 +122,7 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#elif defined(__gfx9__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
 * \tparam AElementwiseOperation A elementwise operation.
 * \tparam BElementwiseOperation B elementwise operation.
 * \tparam CDEElementwiseOperation CDE elementwise operation.
 * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
 * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
 * \tparam BComputeType Compute data type for B tensor (default: AComputeType).
 */
template <index_t NDimSpatial,
          typename ALayout,
@@ -54,12 +55,13 @@ template <index_t NDimSpatial,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename ComputeType =
          typename AComputeType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>())> // ComputeType is InputType by default (first
                                      ADataType>()), // AComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
          typename BComputeType = AComputeType>
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
{
    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;

@@ -0,0 +1,136 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <iostream>
#include <vector>
#include <sstream>

#include "device_grouped_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

///
/// @brief  Structure representing single GEMM problem arguments.
///
///         The pointer to the vector of those structures is passed to the GroupedGEMM entry
///         point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
    __host__ __device__
    GroupedGemmMultipleDKernelArguments(const void* p_a_grid_,
                                        const void* p_b_grid_,
                                        std::array<const void*, NumDTensor> p_ds_grid_,
                                        void* p_e_grid_,
                                        index_t M_,
                                        index_t N_,
                                        index_t K_,
                                        index_t StrideA_,
                                        index_t StrideB_,
                                        std::array<index_t, NumDTensor> StrideDs_,
                                        index_t StrideE_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_ds_grid{p_ds_grid_},
          p_e_grid{p_e_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideDs{StrideDs_},
          StrideE{StrideE_}
    {
    }

    const void* p_a_grid;
    const void* p_b_grid;
    std::array<const void*, NumDTensor> p_ds_grid;
    void* p_e_grid;
    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    std::array<index_t, NumDTensor> StrideDs;
    index_t StrideE;

    void Print() const
    {
        std::stringstream str;
        for(auto sd : StrideDs)
            str << sd << ",";

        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  << "SA:" << StrideA << ", "
                  << "SB:" << StrideB << ", "
                  << "SE:" << StrideE << ", "
                  << "SDs: {" << str.str() << "}"
                  << "}" << std::endl;
    }
};
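// A minimal host-side sketch of filling one per-group argument structure
// (pointers and sizes below are hypothetical):
//
//     std::array<const void*, 2> p_ds{d0_dev, d1_dev};
//     std::array<ck::index_t, 2> stride_ds{1024, 0};
//     ck::tensor_operation::device::GroupedGemmMultipleDKernelArguments<2> karg{
//         a_dev, b_dev, p_ds, e_dev,
//         /*M=*/512, /*N=*/1024, /*K=*/256,
//         /*StrideA=*/256, /*StrideB=*/256, stride_ds, /*StrideE=*/1024};
//     karg.Print();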

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
                                                                   BLayout,
                                                                   DsLayout,
                                                                   ELayout,
                                                                   ADataType,
                                                                   BDataType,
                                                                   DsDataType,
                                                                   EDataType,
                                                                   AElementwiseOperation,
                                                                   BElementwiseOperation,
                                                                   CDEElementwiseOperation>
{
    //----------------------------------------------------------------------------------------------
    /// @brief      Sets the k batch size.
    ///
    /// @param      p_arg   Pointer to the Argument we're going to change.
    /// @param[in]  kbatch  The kbatch value.
    ///
    virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;

    //----------------------------------------------------------------------------------------------
    /// @brief      Sets the device kernel arguments pointer.
    ///
    /// @param      p_arg              The pointer to the Argument we're going to update.
    /// @param[in]  p_dev_kernel_args  The pointer to the device memory which contains kernel
    ///                                arguments.
    ///
    virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;

    //----------------------------------------------------------------------------------------------
    /// @brief      Gets the device kernel argument size.
    ///
    /// @param[in]  p_arg  The pointer to the Device op Argument.
    ///
    /// @return     The device kernel argument size.
    ///
    virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
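// A hedged sketch of how a client might drive this interface (the concrete op
// instance, host-side argument vectors, and the hipMalloc/H2D-copy steps are
// assumed, not shown by this header):
//
//     auto arg = op.MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
//                                       a_op, b_op, cde_op);
//     op.SetKBatchSize(arg.get(), 4);
//     void* p_dev_args; // allocate GetDeviceKernelArgSize(arg.get()) bytes, copy args over
//     op.SetDeviceKernelArgs(arg.get(), p_dev_args);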

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -22,10 +22,12 @@ namespace device {
template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
          index_t NumDim,
          index_t MPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
          index_t NumDim,                 // the max dim of the input tensors;
                                          // the tensor descs have to be aligned such that
                                          // the innermost dim is the contiguous one
          index_t MPerThread,             // how many elements per thread to read
          typename InScalarPerVectorSeq,  // scalars per vector for each input
          typename OutScalarPerVectorSeq> // scalars per vector for each output
struct DeviceElementwiseImpl
    : public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
@@ -242,13 +244,13 @@ struct DeviceElementwiseImpl
        static_for<0, NumInput, 1>{}([&](auto I) {
            if(!IsScalarPerVectorValid(
                   arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
                valid = false;
                valid = valid && false;
        });

        static_for<0, NumOutput, 1>{}([&](auto I) {
            if(!IsScalarPerVectorValid(
                   arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
                valid = false;
                valid = valid && false;
        });

        return valid;

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -254,13 +254,14 @@ template <index_t NDimSpatial,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          typename ComputeDataType =
          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
          LoopScheduler LoopSched = make_default_loop_scheduler()>
          typename BComputeDataType = AComputeDataType,
          LoopScheduler LoopSched   = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                             ALayout,
@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                             AElementwiseOperation,
                                             BElementwiseOperation,
                                             CDEElementwiseOperation,
                                             ComputeDataType>
                                             AComputeDataType,
                                             BComputeDataType>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;

#define GridwiseGemmTemplateParameters                                                          \
    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType,   \
    GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType,  \
        EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,       \
        InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
        KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,                        \
@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN,                           \
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,                           \
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,                       \
        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1,              \
        BComputeDataType
    // Use appropriate gridwise gemm
    using GridwiseGemm =
        std::conditional_t<isMultiA || isMultiB,

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -75,13 +75,14 @@ template <index_t NDimSpatial,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          typename ComputeDataType =
          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
          LoopScheduler LoopSched = make_default_loop_scheduler()>
          typename BComputeDataType = AComputeDataType,
          LoopScheduler LoopSched   = make_default_loop_scheduler()>
using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
    NDimSpatial,
    ALayout,
@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl
    CShuffleNXdlPerWavePerShuffle,
    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
    CDEBlockTransferScalarPerVector_NPerBlock,
    ComputeDataType,
    AComputeDataType,
    BComputeDataType,
    LoopSched>;

} // namespace device

@@ -0,0 +1,987 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <tuple>

#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/common_header.hpp"
#include <ck/utility/loop_scheduler.hpp>
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          typename ComputeDataType    = EDataType,
          // TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
          enable_if_t<AK1 == BK1, bool> = false>
struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
    : public DeviceGroupedGemmMultipleDSplitK<ALayout,
                                              BLayout,
                                              DsLayout,
                                              ELayout,
                                              ADataType,
                                              BDataType,
                                              DsDataType,
                                              EDataType,
                                              AElementwiseOperation,
                                              BElementwiseOperation,
                                              CDEElementwiseOperation>
{
    using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    // TODO: change GridwiseGEMM v2r4r2 to support separate AK1 & BK1
    static constexpr index_t K0PerBlock = KPerBlock / AK1;

    using PassThrough       = ck::tensor_operation::element_wise::PassThrough;
    using WorkspaceDataType = float;

    // First stage GridwiseGEMM kernel.
    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
        BlockSize,
        ADataType,
        BDataType,
        AccDataType,
        WorkspaceDataType,
        ALayout,
        BLayout,
        ELayout,
        AElementwiseOperation,
        BElementwiseOperation,
        PassThrough, // CElementwiseOperation
        GemmSpec,
        NumGemmKPrefetchStage,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        AK1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        LoopSched,
        PipelineVer,
        ComputeDataType>;
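    // Editorial summary of the two-stage flow implemented below: stage one runs
    // this split-K GridwiseGemm with InMemoryDataOperationEnum::AtomicAdd,
    // accumulating partial results into a float workspace; stage two runs
    // GridwiseElementwise per group, reading the workspace plus the D tensors
    // and applying CDEElementwiseOperation to produce E (see DispatchKernel /
    // LaunchKernel).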

    template <typename ELay>
    static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                                         const std::array<index_t, NumDTensor>& NRaws,
                                         const std::array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
            },
            Number<NumDTensor>{});
    }

    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                return static_cast<const DDataType*>(nullptr);
            },
            Number<NumDTensor>{});
    }

    static constexpr auto MakeElementwiseInputSequence()
    {
        return generate_sequence_v2(
            [&]([[maybe_unused]] auto i) constexpr {
                return Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
            },
            Number<NumDTensor + 1>{});
    }

    using CGridDesc_M_N  = typename GridwiseGemm::CGridDesc_M_N;
    using EGridDesc_M_N  = typename GridwiseGemm::CGridDesc_M_N;
    using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
    using DsGridPointer  = decltype(MakeDsGridPointer());
    using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
    using CDDataTypes    = decltype(concat_tuple(ck::Tuple<WorkspaceDataType*>{}, DsGridPointer{}));

    using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence());

    static constexpr index_t ClusterLengthMPerBlock =
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
    static constexpr index_t ClusterLengthNPerBlock =
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);

    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
    using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
    using GridwiseElementwise =
        GridwiseElementwise<CDGridDesc_M_N,
                            ck::Tuple<EGridDesc_M_N>,
                            CDDataTypes,
                            ck::Tuple<EDataType*>,
                            Block2TileMap,
                            CDEElementwiseOperation,
                            BlockSize,
                            MPerBlock,
                            NPerBlock,
                            MPerBlock / ClusterLengthMPerBlock,
                            NPerBlock / ClusterLengthNPerBlock,
                            Sequence<0, 1>,
                            ElementwiseInputSequence,
                            ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>,
                            true>;

    // Block2CTileMap configuration parameter.
    static constexpr index_t B2E_M01 = 8;
    using GroupedGemmBlock2ETileMap  = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
    using GemmKernelArgument         = typename GridwiseGemm::Argument;

    struct GemmTransKernelArg
    {
        GemmKernelArgument karg_;
        GroupedGemmBlock2ETileMap block_2_ctile_map_;
        index_t block_start_, block_end_;

        GemmTransKernelArg() = default;
        GemmTransKernelArg(GemmKernelArgument&& karg,
                           GroupedGemmBlock2ETileMap&& b2c_map,
                           index_t block_start,
                           index_t block_end)
            : karg_{karg},
              block_2_ctile_map_{b2c_map},
              block_start_{block_start},
              block_end_{block_end}
        {
        }
    };

    static constexpr index_t DefaultKBatch = 1;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : Argument(p_As,
                       p_Bs,
                       p_Ds,
                       p_Es,
                       gemm_descs,
                       a_element_op,
                       b_element_op,
                       cde_element_op,
                       DefaultKBatch)
        {
        }

        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op,
                 index_t kbatch)
            : K_BATCH{kbatch},
              group_count_{0},
              skipped_group_count_{0},
              grid_size_{0},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              p_Ds_{p_Ds}
        {
            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
            {
                throw std::runtime_error("Error! group_count_ != p_As/Bs/Ds/Es size");
            }

            gemm_kernel_args_.reserve(group_count_);
            elementwise_c_grid_descs_m_n_.reserve(group_count_);
            elementwise_d_grid_descs_m_n_.reserve(group_count_);
            ds_grid_pointer_.reserve(group_count_);
            group_grid_size_.resize(group_count_); // resized (not reserved): elements are indexed below

            for(std::size_t i = 0; i < gemm_descs.size(); ++i)
            {
                const index_t M = gemm_descs[i].M_;
                const index_t N = gemm_descs[i].N_;
                const index_t K = gemm_descs[i].K_;

                if(M * N * K == 0)
                {
                    skipped_group_count_++;
                    continue;
                }

                const index_t stride_a = gemm_descs[i].stride_A_;
                const index_t stride_b = gemm_descs[i].stride_B_;
                const index_t stride_e = gemm_descs[i].stride_C_;

                const index_t m_padded  = GridwiseGemm::CalculateMPadded(M);
                const index_t n_padded  = GridwiseGemm::CalculateNPadded(N);
                const index_t k_padded  = GridwiseGemm::CalculateKPadded(K, K_BATCH);
                const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);

                const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e);

                DsGridDesc_M_N ds_grid_desc_m_n;
                DsGridPointer p_ds_grid;

                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    using DLayout   = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    p_ds_grid(j)        = static_cast<const DDataType*>(p_Ds[i][j]);
                    ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
                        M, N, gemm_descs[i].stride_Ds_[j]);
                });
                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;
                group_grid_size_[i] = grid_size_grp;
                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                std::array<index_t, NumDTensor> stride_ds;

                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
                    {
                        throw std::runtime_error(
                            "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
                    }

                    stride_ds[j] = gemm_descs[i].stride_Ds_[j];
                });
                stride_Ds_.emplace_back(std::move(stride_ds));

                // We first set the E pointer to the actual operation output; later on,
                // when the workspace is set, this is updated to workspace memory.
                auto karg = GemmKernelArgument{type_convert<const ADataType*>(p_As[i]),
                                               type_convert<const BDataType*>(p_Bs[i]),
                                               type_convert<WorkspaceDataType*>(p_Es[i]),
                                               M,
                                               N,
                                               K,
                                               stride_a,
                                               stride_b,
                                               stride_e,
                                               m_padded,
                                               n_padded,
                                               k_padded,
                                               k0_padded,
                                               K_BATCH};

                gemm_kernel_args_.emplace_back(
                    std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);

                elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
                elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
                ds_grid_pointer_.push_back(p_ds_grid);
            }
            // Store a copy of E pointers for the elementwise kernel destination.
            e_ptrs_ = p_Es;
        }

        /**
         * @brief      Set new kbatch value.
         *
         * @param[in]  kbatch  The new splitK parameter value.
         */
        void UpdateKBatch(index_t kbatch)
        {
            K_BATCH    = kbatch;
            grid_size_ = 0;

            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
            {
                auto& karg = gemm_kernel_args_[i].karg_;

                const index_t k_padded  = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
                const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);

                const auto c_grid_desc_m_n =
                    GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                group_grid_size_[i] = grid_size_grp;
                karg.KPadded        = k_padded;
                karg.K0Padded       = k0_padded;
                karg.k_batch        = K_BATCH;
                gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
                gemm_kernel_args_[i].block_start_       = block_start;
                gemm_kernel_args_[i].block_end_         = block_end;

#if DEBUG_LOG
                index_t tiles = (block_end - block_start) / K_BATCH;
                std::cout << "block_start: " << block_start << "\n"
                          << "block_end: " << block_end << "\n"
                          << "tiles: " << tiles << std::endl
                          << std::endl;

                std::cout << "KPadded: " << karg.KPadded << std::endl
                          << "K0Padded: " << karg.K0Padded << std::endl
                          << "KBatch: " << karg.k_batch << std::endl
                          << "grid_size_: " << grid_size_ << std::endl;
#endif
            }
        }

        void UpdateEPointers()
        {
            // Set up each group's E pointer to its designated workspace memory.
            WorkspaceDataType* p_workspace = reinterpret_cast<WorkspaceDataType*>(p_workspace_);
            std::size_t offset             = 0;

            for(auto& arg : gemm_kernel_args_)
            {
                arg.karg_.p_c_grid = p_workspace + offset;
                index_t tiles      = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
                offset += tiles * MPerBlock * NPerBlock;
#if DEBUG_LOG
                std::cout << "block_start: " << arg.block_start_ << "\n"
                          << "block_end: " << arg.block_end_ << "\n"
                          << "tiles: " << tiles << "\n"
                          << "offset: " << offset << std::endl;
#endif
            }
        }
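        // Worked example (illustrative numbers): with two groups covering 4 and
        // 6 output tiles and MPerBlock = NPerBlock = 128, group 0's partial
        // results start at element offset 0 and group 1's at 4 * 128 * 128 into
        // the shared workspace.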

        std::size_t GetWorkspaceSizeBytes() const
        {
            std::size_t size_bytes{0};

            for(const auto& arg : gemm_kernel_args_)
            {
                index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
                size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType);
            }
            return size_bytes;
        }

        std::size_t GetWorkspaceSize(std::size_t group) const
        {
            const auto& arg = gemm_kernel_args_[group];
            index_t tiles   = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
            return tiles * MPerBlock * NPerBlock;
        }
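        // Worked example (illustrative): a group covering 8 output tiles with
        // MPerBlock = NPerBlock = 128 needs 8 * 128 * 128 * sizeof(float)
        // = 512 KiB of workspace, independent of k_batch, since partial results
        // are atomically accumulated in place per tile.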

        // private:
        index_t K_BATCH;
        index_t group_count_;
        index_t skipped_group_count_;
        index_t grid_size_;
        // Pointer to device memory with GEMM kernel arguments.
        const void* p_dev_gemm_args_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        std::vector<std::array<const void*, NumDTensor>>& p_Ds_;
        std::vector<std::array<index_t, NumDTensor>> stride_Ds_;
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        std::vector<index_t> group_grid_size_;

        std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
        std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
        std::vector<DsGridPointer> ds_grid_pointer_;
        std::vector<void*> e_ptrs_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        ///
        /// @brief      Launch Grouped Gemm kernel.
        ///
        /// @note       This function overload is using a user-provided device buffer for kernel
        ///             arguments.
        ///
        /// @param[in]  arg                 The structure containing kernel arguments (in host
        ///                                 memory).
        /// @param[in]  dev_gemm_args       The pointer to device memory with kernel arguments.
        /// @param[in]  dev_gemm_workspace  The pointer to device memory for kernel auxiliary
        ///                                 workspace.
        /// @param[in]  stream_config       The device stream configuration.
        ///
        /// @return     The average kernel execution time (if time measurement is enabled).
        ///
        float Run(const Argument& arg,
                  const void* dev_gemm_args,
                  void* dev_gemm_workspace,
                  const StreamConfig& stream_config = StreamConfig{})
        {
            auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
                CheckArgument(arg, stream_config);

            if(dev_gemm_args == nullptr)
            {
                std::ostringstream err;
                err << "The gemm arguments device buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            if(dev_gemm_workspace == nullptr)
            {
                std::ostringstream err;
                err << "The gemm workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            float ave_time = 0;

            if(all_have_main_k_block_loop)
            {
                ave_time =
                    DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
            }

            return ave_time;
        }

        ///
        /// @brief      Launch Grouped Gemm kernel.
        ///
        /// @note       This function overload is using device buffers (for kernel arguments and
        ///             for kernel auxiliary workspace) provided with the argument. The user
        ///             should call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see
        ///             SetDeviceKernelArgs, @see SetWorkSpacePointer on the arg parameter to
        ///             properly allocate those buffers.
        ///
        /// @param[in]  arg            The structure containing kernel arguments (in host memory).
        /// @param[in]  stream_config  The device stream configuration.
        ///
        /// @return     The average kernel execution time (if time measurement is enabled).
        ///
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(arg.p_dev_gemm_args_ == nullptr)
            {
                std::ostringstream err;
                err << "The gemm arguments device buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            if(arg.p_workspace_ == nullptr)
            {
                std::ostringstream err;
                err << "The gemm workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }

        private:
        auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
        {
            bool all_have_kbatch_gt_one, all_have_main_k_block_loop;

            {
                const auto a_grid_desc_kbatch_ak0_m_ak1 =
                    GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(
                        arg.gemm_kernel_args_[0].karg_.M,
                        arg.gemm_kernel_args_[0].karg_.MPadded,
                        arg.gemm_kernel_args_[0].karg_.K,
                        arg.gemm_kernel_args_[0].karg_.StrideA,
                        arg.gemm_kernel_args_[0].karg_.k_batch,
                        arg.gemm_kernel_args_[0].karg_.K0Padded,
                        arg.gemm_kernel_args_[0].karg_.KPadded);

                all_have_kbatch_gt_one     = arg.K_BATCH > 1;
                all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
            }

            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
            {
                const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
                if(stream_config.log_level_ > 0)
                {
                    gemm_arg.Print();
                }

                if(!GridwiseGemm::CheckValidity(gemm_arg))
                {
                    std::ostringstream err;
                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                const auto a_grid_desc_kbatch_ak0_m_ak1 =
                    GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
                                                                     gemm_arg.MPadded,
                                                                     gemm_arg.K,
                                                                     gemm_arg.StrideA,
                                                                     gemm_arg.k_batch,
                                                                     gemm_arg.K0Padded,
                                                                     gemm_arg.KPadded);

                bool not_all_have_main_k_block_loop_same =
                    all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
                                                       a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
                                                       a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
                bool not_all_have_kbatch_value_same =
                    all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);

                if(not_all_have_main_k_block_loop_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                if(not_all_have_kbatch_value_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same kbatch value (=1 or >1)! "
                        << "group [" << i << "], kbatch: " << gemm_arg.k_batch
                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }
            }
            return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
        }

        template <bool HasMainKBlockLoop>
        float DispatchKernel(const Argument& arg,
                             const void* dev_gemm_args,
                             void* dev_gemm_workspace,
                             const StreamConfig& stream_config) const
        {
            const auto gemm_kernel =
                kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                               GemmTransKernelArg,
                                               HasMainKBlockLoop,
                                               InMemoryDataOperationEnum::AtomicAdd,
                                               AElementwiseOperation,
                                               BElementwiseOperation,
                                               PassThrough>;

            const auto elementwise_kernel = kernel_elementwise<GridwiseElementwise,
                                                               CDGridDesc_M_N,
                                                               ck::Tuple<EGridDesc_M_N>,
                                                               CDDataTypes,
                                                               ck::Tuple<EDataType*>,
                                                               Block2TileMap,
                                                               CDEElementwiseOperation>;
            return LaunchKernel(gemm_kernel,
                                elementwise_kernel,
                                arg,
                                dev_gemm_args,
                                dev_gemm_workspace,
                                stream_config);
        }

        template <typename KernelFunction, typename KernelFunction2>
        float LaunchKernel(const KernelFunction& gemm_kernel,
                           const KernelFunction2& elementwise_kernel,
                           const Argument& arg,
                           const void* dev_gemm_args,
                           [[maybe_unused]] void* dev_gemm_workspace,
                           const StreamConfig& stream_config) const
        {
            float time{0.f};

            auto preprocess = [&]() {
                hip_check_error(hipMemsetAsync(
                    dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
            };

            // GEMM kernel
            time = launch_and_time_kernel_with_preprocess(
                stream_config,
                preprocess,
                gemm_kernel,
                dim3(arg.grid_size_),
                dim3(BlockSize),
                0,
                cast_pointer_to_constant_address_space(dev_gemm_args),
                arg.group_count_,
                arg.a_element_op_,
                arg.b_element_op_,
                PassThrough{});

            // Elementwise kernels
            for(int i = 0; i < arg.group_count_; ++i)
            {
                time += launch_and_time_kernel(
                    stream_config,
                    elementwise_kernel,
                    dim3(arg.group_grid_size_[i]),
                    dim3(BlockSize),
                    0,
                    concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
                                 arg.elementwise_d_grid_descs_m_n_[i]),
                    make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
                    concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
                                 arg.ds_grid_pointer_[i]),
                    type_convert<EDataType*>(arg.e_ptrs_[i]),
                    Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0),
                                  arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)},
                    arg.cde_element_op_);
            }
            return time;
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
            arg.skipped_group_count_) != arg.group_count_)
        {
#if DEBUG_LOG
            std::cout << "The group count is not equal to sum of skipped groups "
                         "and kernel args size!"
                      << std::endl;
#endif // DEBUG_LOG
            return false;
        }

        bool supported = true;
        for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
        {
            const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;

            bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
            if(not group_arg_valid)
            {
#if DEBUG_LOG
                std::cout << "[" << __func__ << "] group id: " << i
                          << " has invalid GridwiseGemm settings!" << std::endl;
                gemm_arg.Print();
#endif // DEBUG_LOG
            }
            supported = supported && group_arg_valid;
        }
        return supported;
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*>& p_As,
                             std::vector<const void*>& p_Bs,
                             std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                             std::vector<void*>& p_Es,
                             std::vector<GemmDesc> gemm_descs,
                             AElementwiseOperation a_elementwise_op,
                             BElementwiseOperation b_elementwise_op,
                             CDEElementwiseOperation cde_elementwise_op)
    {
        return Argument{p_As,
                        p_Bs,
                        p_Ds,
                        p_Es,
                        gemm_descs,
                        a_elementwise_op,
                        b_elementwise_op,
                        cde_elementwise_op};
    }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_As,
                        std::vector<const void*>& p_Bs,
                        std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                        std::vector<void*>& p_Es,
                        std::vector<GemmDesc>& gemm_descs,
                        AElementwiseOperation a_elementwise_op,
                        BElementwiseOperation b_elementwise_op,
                        CDEElementwiseOperation cde_elementwise_op) override
    {
        return std::make_unique<Argument>(p_As,
                                          p_Bs,
                                          p_Ds,
                                          p_Es,
                                          gemm_descs,
                                          a_elementwise_op,
                                          b_elementwise_op,
                                          cde_elementwise_op);
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
<< std::string(ELayout::name)[0] << ","
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
|
||||
{
|
||||
arg.p_dev_gemm_args_ = p_dev_kernel_args;
|
||||
hip_check_error(hipMemcpy(p_dev_kernel_args,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
GetDeviceKernelArgSize(&arg),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
|
||||
{
|
||||
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(
|
||||
BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
[[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
p_arg_->UpdateEPointers();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
|
||||
}
|
||||
|
||||
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
|
||||
|
||||
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
|
||||
{
|
||||
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
|
||||
sizeof(GemmTransKernelArg);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
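// Editor's note: a minimal host-side sketch of how a client might drive the
// two-stage op above, assembled only from the member functions shown in this
// file; the Invoker::Run signature and the template arguments are assumptions.
//
//   auto op  = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage</*...*/>{};
//   auto arg = op.MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
//                              a_op, b_op, cde_op);
//   op.SetKBatchSize(arg, 4);                        // split K into 4 batches
//   hipMalloc(&p_workspace, op.GetWorkSpaceSize(&arg));
//   op.SetWorkSpacePointer(&arg, p_workspace);       // also updates E pointers
//   hipMalloc(&p_dev_args, op.GetDeviceKernelArgSize(&arg));
//   op.SetDeviceKernelArgs(arg, p_dev_args);         // copies kernel args to device
//   if(op.IsSupportedArgument(arg))
//       op.MakeInvoker().Run(arg, StreamConfig{});   // assumed Run signature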
@@ -23,6 +23,7 @@ namespace device {
template <typename GridwiseGemm,
          typename GemmDesc,
          GemmSpecialization GemmSpec,
          bool Zeroing,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
@@ -106,33 +107,63 @@ __global__ void
        const auto block_2_etile_map =
            GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

        auto barrier_count_finished =
            barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
        if constexpr(Zeroing)
        {
            auto barrier_count_finished =
                barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
            GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
                                                  EGlobalMemoryDataOperation,
                                                  GemmSpec,
                                                  ALayout,
                                                  BLayout,
                                                  DsLayout,
                                                  ELayout>(gemm_desc_ptr[group_id].p_a_grid,
                                                           gemm_desc_ptr[group_id].p_b_grid,
                                                           p_ds_grid_,
                                                           gemm_desc_ptr[group_id].p_e_grid,
                                                           p_shared,
                                                           barrier_count_finished,
                                                           a_element_op,
                                                           b_element_op,
                                                           c_element_op,
                                                           M,
                                                           N,
                                                           K,
                                                           StrideA,
                                                           StrideB,
                                                           StrideDs,
                                                           StrideE,
                                                           KBatch,
                                                           block_2_etile_map);
        }
        else
        {

            GridwiseGemm::template Run<HasMainKBlockLoop,
                                       EGlobalMemoryDataOperation,
                                       GemmSpec,
                                       ALayout,
                                       BLayout,
                                       DsLayout,
                                       ELayout>(gemm_desc_ptr[group_id].p_a_grid,
                                                gemm_desc_ptr[group_id].p_b_grid,
                                                p_ds_grid_,
                                                gemm_desc_ptr[group_id].p_e_grid,
                                                p_shared,
                                                barrier_count_finished,
                                                a_element_op,
                                                b_element_op,
                                                c_element_op,
                                                M,
                                                N,
                                                K,
                                                StrideA,
                                                StrideB,
                                                StrideDs,
                                                StrideE,
                                                KBatch,
                                                block_2_etile_map);
            GridwiseGemm::template Run<HasMainKBlockLoop,
                                       EGlobalMemoryDataOperation,
                                       GemmSpec,
                                       ALayout,
                                       BLayout,
                                       DsLayout,
                                       ELayout>(gemm_desc_ptr[group_id].p_a_grid,
                                                gemm_desc_ptr[group_id].p_b_grid,
                                                p_ds_grid_,
                                                gemm_desc_ptr[group_id].p_e_grid,
                                                p_shared,
                                                nullptr,
                                                a_element_op,
                                                b_element_op,
                                                c_element_op,
                                                M,
                                                N,
                                                K,
                                                StrideA,
                                                StrideB,
                                                StrideDs,
                                                StrideE,
                                                KBatch,
                                                block_2_etile_map);
        }

        id_off += grid_size_grp;
        id_local += grid_size_grp;
@@ -193,8 +224,11 @@ template <typename ALayout,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          typename ComputeType = ADataType,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
          PipelineVersion PipelineVer = PipelineVersion::v1,
          LoopScheduler LoopSched = make_default_loop_scheduler(),
          typename ComputeType = ADataType,
          typename ALDSType = ComputeType,
          typename BLDSType = ComputeType>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                        BLayout,
                                                                        DsLayout,
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using AComputeType = ComputeType;
    using BComputeType = ComputeType;

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        BDataType,
        ComputeType,
        AComputeType,
        BComputeType,
        AccDataType,
        CShuffleDataType,
        DsDataType,
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
        LoopSched,
        PipelineVer,
        ALDSType,
        BLDSType>;

    template <typename UnderlyingBlockToCTileMap>
    struct OffsettedBlockToCTileMapMLoops
@@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        float ave_time = 0;

        auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
            const auto kernel =
                kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                 GroupedGemmKernelArgument<NumDTensor>,
                                                 GemmSpec,
                                                 ALayout,
                                                 BLayout,
                                                 DsLayout,
                                                 ELayout,
                                                 DsDataType,
                                                 Block2ETileMap,
                                                 GroupedGemmBlock2ETileMap,
                                                 AElementwiseOperation,
                                                 BElementwiseOperation,
                                                 CDEElementwiseOperation,
                                                 e_global_memory_operation_,
                                                 has_main_k_block_loop_>;
            if(arg.k_batch_ == 1)
            {
                const auto kernel =
                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                     GroupedGemmKernelArgument<NumDTensor>,
                                                     GemmSpec,
                                                     false,
                                                     ALayout,
                                                     BLayout,
                                                     DsLayout,
                                                     ELayout,
                                                     DsDataType,
                                                     Block2ETileMap,
                                                     GroupedGemmBlock2ETileMap,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CDEElementwiseOperation,
                                                     e_global_memory_operation_,
                                                     has_main_k_block_loop_>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
                    reinterpret_cast<uint32_t*>(arg.p_workspace_),
                    arg.barrier_size_grp_,
                    arg.gemm_desc_kernel_arg_.size(),
                    arg.grid_size_grp_,
                    arg.k_batch_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_);
                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
                    nullptr,
                    arg.barrier_size_grp_,
                    arg.gemm_desc_kernel_arg_.size(),
                    arg.grid_size_grp_,
                    arg.k_batch_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_);
            }
            else
            {
                const auto kernel =
                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                     GroupedGemmKernelArgument<NumDTensor>,
                                                     GemmSpec,
                                                     true,
                                                     ALayout,
                                                     BLayout,
                                                     DsLayout,
                                                     ELayout,
                                                     DsDataType,
                                                     Block2ETileMap,
                                                     GroupedGemmBlock2ETileMap,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CDEElementwiseOperation,
                                                     e_global_memory_operation_,
                                                     has_main_k_block_loop_>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
                    reinterpret_cast<uint32_t*>(arg.p_workspace_),
                    arg.barrier_size_grp_,
                    arg.gemm_desc_kernel_arg_.size(),
                    arg.grid_size_grp_,
                    arg.k_batch_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_);
            }
        };

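// Editor's note: the branch above specializes the kernel on the Zeroing flag.
// With k_batch_ == 1 every output tile is written exactly once, so no barrier
// workspace is needed and nullptr is passed; with k_batch_ > 1 the partial
// K-slices accumulate atomically, so the output must first be zeroed and a
// per-tile barrier counter (arg.p_workspace_) orders zeroing vs. accumulation.
// A distilled sketch of the same selection (names taken from the code above):
//
//   const bool zeroing = arg.k_batch_ > 1;
//   uint32_t* barrier  = zeroing ? reinterpret_cast<uint32_t*>(arg.p_workspace_)
//                                : nullptr;
//   // the kernel template argument 'Zeroing' must match 'zeroing' at compile
//   // time, hence the two otherwise identical instantiations.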
        constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
        constexpr auto Set = InMemoryDataOperationEnum::Set;

        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
        // in IsSupportedArgument function
        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is
        // enforced in IsSupportedArgument function
        if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
        {
            if(has_main_k_block_loop)
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,

        bool supported = true;

        // If we use padding we do not support vector loads for dimensions not divisible by vector
        // load size.
        // If we use padding we do not support vector loads for dimensions not divisible by
        // vector load size.
        if constexpr(GemmSpec != GemmSpecialization::Default)
        {
            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
            // thus we have to adapt it to the {M,K} or {N,K} layout.
            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
            // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
            const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
            const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -26,13 +26,19 @@ namespace device {
template <typename GridwiseGemm,
          typename GemmDesc,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
          typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
          typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                       const index_t group_count)
                                       const index_t group_count,
                                       const AElementwiseOperation a_element_op,
                                       const BElementwiseOperation b_element_op,
                                       const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
@@ -64,10 +70,16 @@ __global__ void
    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
        gemm_desc_ptr[group_id].karg_,
        static_cast<void*>(p_shared),
        gemm_desc_ptr[group_id].block_2_ctile_map_);
        gemm_desc_ptr[group_id].block_2_ctile_map_,
        a_element_op,
        b_element_op,
        c_element_op);
#else
    ignore = gemm_descs_const;
    ignore = group_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

@@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
    static constexpr index_t B2E_M01 = 8;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
    using KernelArgument = typename GridwiseGemm::Argument;

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    struct GemmTransKernelArg
    {
        KernelArgument karg_;
@@ -437,7 +449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                dim3(BlockSize),
                0,
                cast_pointer_to_constant_address_space(arg.p_workspace_),
                arg.gemm_kernel_args_.size());
                arg.gemm_kernel_args_.size(),
                PassThrough{},
                PassThrough{},
                PassThrough{});
        };

        if(all_have_main_k0_block_loop)

@@ -92,6 +92,110 @@ struct Add
    };
};

struct Max
{
    template <typename Y, typename X0, typename X1>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        const Y x0_converted = type_convert<Y>(x0);
        const Y x1_converted = type_convert<Y>(x1);
        y = ck::math::max(x0_converted, x1_converted);
    }
};

struct Min
{
    template <typename Y, typename X0, typename X1>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        const Y x0_converted = type_convert<Y>(x0);
        const Y x1_converted = type_convert<Y>(x1);
        y = ck::math::min(x0_converted, x1_converted);
    }
};

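// Editor's note: Max and Min convert both inputs to the destination type Y
// before comparing, so mixed-precision operands are handled uniformly. A
// hedged host-side usage sketch (types chosen purely for illustration):
//
//   ck::tensor_operation::element_wise::Max max_op;
//   float y = 0.f;
//   ck::half_t x0 = ck::type_convert<ck::half_t>(1.5f);
//   max_op(y, x0, 2.0f); // both operands converted to float first; y == 2.0f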
struct Multiply
{
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const float& x1) const
    {
        y = x0 * x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<double>(double& y, const double& x0, const double& x1) const
    {
        y = x0 * x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const half_t& x1) const
    {
        y = x0 * type_convert<float>(x1);
    };

    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const float& x1) const
    {
        y = type_convert<half_t>(x0 * x1);
    };

    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
    {
        y = type_convert<half_t>(x0) * x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
    {
        y = x0 * x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
    {
        const float x1_tmp = ck::type_convert<float>(x1);
        y = x0 * x1_tmp;
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
    {
        const float x0_tmp = ck::type_convert<float>(x0);
        const float x1_tmp = ck::type_convert<float>(x1);
        const float y_tmp = x0_tmp * x1_tmp;
        y = ck::type_convert<bhalf_t>(y_tmp);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
    {
        const float x1_tmp = ck::type_convert<float>(x1);
        const float y_tmp = x0 * x1_tmp;
        y = ck::type_convert<bhalf_t>(y_tmp);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
    {
        y = x0 * x1;
    };
};

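// Editor's note: Multiply declares the generic operator() without a definition
// and provides explicit specializations per destination/operand combination, so
// an unsupported type mix fails at build time instead of silently converting.
// (Two small fixes were applied above: the float/half_t overload now converts
// x1 to float as clearly intended, and the bf16 temporaries were renamed to
// match their operands.) A hedged sketch of the dispatch:
//
//   ck::tensor_operation::element_wise::Multiply mul;
//   ck::half_t y;
//   mul(y, 2.0f, ck::type_convert<ck::half_t>(3.0f));
//   // deduces Y = half_t, X0 = float, X1 = half_t and picks that overload;
//   // a combination without a specialization would not link.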
struct ScaleAdd
{
    __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace element_wise {

// y = UnaryOpN(...(UnaryOp1(UnaryOp0(x))))
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
    __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}

    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // Execute first unary op to copy data to y
        unary_ops_.At(Number<0>{})(y, x);

        static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
    };

    Tuple<UnaryOpsSet...> unary_ops_;
};

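// Editor's note: a hedged usage sketch of UnaryCombinedOp. Relu and Scale are
// existing unary ops in this namespace; the particular composition below is
// illustrative only. Note the op listed first is applied first (innermost).
//
//   using ck::tensor_operation::element_wise::Relu;
//   using ck::tensor_operation::element_wise::Scale;
//   UnaryCombinedOp<Relu, Scale> relu_then_scale{Relu{}, Scale{2.f}};
//   float y = 0.f;
//   relu_then_scale(y, -3.f); // Relu(-3) = 0, then scaled: y == 0.f
//   relu_then_scale(y, 3.f);  // Relu(3)  = 3, then scaled: y == 6.f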
// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
    __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
                                                  UnaryOp0 unary_op0,
                                                  UnaryOp1 unary_op1)
        : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
    {
    }

    template <typename Y, typename X0, typename X1>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;
        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
    };

    private:
    BinaryOp binary_op_;
    UnaryOp0 unary_op0_;
    UnaryOp1 unary_op1_;
};

// y = BinaryOp1(BinaryOp0(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
template <typename BinaryOp0,
          typename BinaryOp1,
          typename UnaryOp0,
          typename UnaryOp1,
          typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
    __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
                                                   BinaryOp1 binary_op1,
                                                   UnaryOp0 unary_op0,
                                                   UnaryOp1 unary_op1,
                                                   UnaryOp2 unary_op2)
        : binary_op0_(binary_op0),
          binary_op1_(binary_op1),
          unary_op0_(unary_op0),
          unary_op1_(unary_op1),
          unary_op2_(unary_op2)
    {
    }

    template <typename Y, typename X0, typename X1, typename X2>
    __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
    {
        Y unary_x0_tmp_result;
        Y unary_x1_tmp_result;
        Y unary_x2_tmp_result;
        unary_op0_(unary_x0_tmp_result, x0);
        unary_op1_(unary_x1_tmp_result, x1);
        unary_op2_(unary_x2_tmp_result, x2);
        binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
        binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
    };

    private:
    BinaryOp0 binary_op0_{};
    BinaryOp1 binary_op1_{};
    UnaryOp0 unary_op0_{};
    UnaryOp1 unary_op1_{};
    UnaryOp2 unary_op2_{};
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
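// Editor's note: the second constructor parameter above was declared BinaryOp0
// instead of BinaryOp1 (fixed), and the leading comment now matches the code,
// which applies binary_op0_ first and binary_op1_ outermost. A hedged usage
// sketch; Add, UnarySquare and PassThrough are existing ops in this namespace,
// picked purely for illustration:
//
//   using namespace ck::tensor_operation::element_wise;
//   // y = Add(UnarySquare(x0), PassThrough(x1)) == x0 * x0 + x1
//   BinaryWithUnaryCombinedOp<Add, UnarySquare, PassThrough> op{
//       Add{}, UnarySquare{}, PassThrough{}};
//   float y = 0.f;
//   op(y, 3.f, 1.f); // y == 10.f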
@@ -12,10 +12,6 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {

#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif

struct PassThroughPack2
{
    template <typename Y, typename X>
@@ -449,11 +445,7 @@ struct FastGelu
        const float u = x * (c1 * x * x + c2);
        const float emu = __expf(u);

#if !CK_WORKAROUND_SWDEV_383542
        y = x * __frcp_rn(1.f + emu);
#else
        y = x * __ocml_native_recip_f32(1.f + emu);
#endif
        y = x * ck::math::rcp(1.f + emu);
    }

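// Editor's note: after this change FastGelu evaluates the sigmoid form
// x / (1 + e^u) through the portable ck::math::rcp instead of branching
// between __frcp_rn and the __ocml workaround. A hedged host-side reference
// (the constant below is an assumption; c1/c2 are defined elsewhere in this
// struct and encode the usual tanh-approximation of gelu):
//
//   float fastgelu_ref(float x)
//   {
//       const float c = 0.7978845608f; // sqrt(2/pi), assumed
//       return 0.5f * x * (1.f + std::tanh(c * (x + 0.044715f * x * x * x)));
//   }
//   // fastgelu_ref(x) should track the device result closely; both are
//   // approximations of exact gelu.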
    template <>
@@ -559,6 +551,244 @@ struct TanH
    };
};

struct ACos
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::acos(x);
    };
};

struct Neg
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::neg(x);
    };
};

struct ATan
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::atan(x);
    };
};

struct Sin
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::sin(x);
    };
};

struct ASinH
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::asinh(x);
    };
};

struct Cos
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::cos(x);
    };
};

struct ACosH
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::acosh(x);
    };
};

struct Tan
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::tan(x);
    };
};

struct ATanH
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::atanh(x);
    };
};

struct SinH
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::sinh(x);
    };
};

struct Ceil
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::ceil(x);
    };
};

struct Exp
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::exp(x);
    };
};

struct CosH
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::cosh(x);
    };
};

struct Floor
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::floor(x);
    };
};

struct Log
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::log(x);
    };
};

struct ASin
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::asin(x);
    };
};

struct Rcp
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
                          is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::rcp(x);
    };
};

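// Editor's note: every unary op above follows one recipe: a static_assert
// whitelist of element types, then a call into ck::math so host and device
// share a single definition. A hedged sketch of adding a new op in the same
// style (Sqrt itself is hypothetical here and assumes a ck::math::sqrt
// wrapper exists):
//
//   struct Sqrt
//   {
//       template <typename T>
//       __host__ __device__ void operator()(T& y, const T& x) const
//       {
//           static_assert(is_same<T, float>::value || is_same<T, double>::value,
//                         "Data type is not supported by this operation!");
//           y = ck::math::sqrt(x);
//       }
//   };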
struct Swish
{
    Swish(float beta = 1.0f) : beta_(beta) {}

@@ -118,8 +118,16 @@ struct GridwiseElementwise
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
        const index_t m1_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
        const auto thread_grid_offset =
            make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
        const auto input_thread_grid_offset = generate_tuple(
            [&](auto) {
                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
            },
            Number<NumInput>{});
        const auto output_thread_grid_offset = generate_tuple(
            [&](auto) {
                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
            },
            Number<NumOutput>{});

        using ThisThreadBlock = ThisThreadBlock<BlockSize>;
        // If src and dst have same vector dim, then:
@@ -157,9 +165,9 @@ struct GridwiseElementwise
            uniform_sequence_gen_t<NumOutput, 1>,
            uniform_sequence_gen_t<NumInput, false>,
            uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
                                                      thread_grid_offset,
                                                      input_thread_grid_offset,
                                                      out_grid_desc_tuple,
                                                      thread_grid_offset,
                                                      output_thread_grid_offset,
                                                      elementwise_op};
        global_to_global_transfer.Run(
            in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);

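// Editor's note: this change replaces the single shared thread_grid_offset
// with per-tensor offset tuples, so the transfer can in principle start each
// input/output at its own coordinate; today every entry is identical. A short
// sketch of what generate_tuple produces (NumInput assumed to be 2 here):
//
//   // generate_tuple(f, Number<2>{}) == make_tuple(f(Number<0>{}), f(Number<1>{}))
//   // so input_thread_grid_offset is a Tuple of two identical multi-indices.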
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -30,7 +30,7 @@ namespace ck {
// D0, D1, ... and E have the same layout
template <typename AsDataType,
          typename BsDataType,
          typename ComputeDataType_,
          typename AComputeDataType_,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
@@ -71,7 +71,8 @@ template <typename AsDataType,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1>
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename BComputeDataType_ = AComputeDataType_>
struct GridwiseGemmMultipleABD_xdl_cshuffle
{
    static constexpr index_t NumATensor = AsDataType::Size();
@@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

#if CK_WORKAROUND_DENORM_FIX
    using ComputeDataType =
        conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
    using AComputeDataType =
        conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
    using BComputeDataType =
        conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
    using ComputeDataType = ComputeDataType_;
    using AComputeDataType = AComputeDataType_;
    using BComputeDataType = BComputeDataType_;
#endif

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(ComputeDataType),
        return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
                             b_block_space_size_aligned * sizeof(BComputeDataType),
                         c_block_size * sizeof(CShuffleDataType));
    }

@@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
            ThisThreadBlock,
            AsDataType,
            Tuple<ComputeDataType>,
            Tuple<AComputeDataType>,
            decltype(as_grid_desc_ak0_m_ak1),
            decltype(tie(a_block_desc_ak0_m_ak1)),
            AElementwiseOperation,
@@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
            ThisThreadBlock,
            BsDataType,
            Tuple<ComputeDataType>,
            Tuple<BComputeDataType>,
            decltype(bs_grid_desc_bk0_n_bk1),
            decltype(tie(b_block_desc_bk0_n_bk1)),
            BElementwiseOperation,
@@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
        constexpr index_t KPack = math::max(
            math::lcm(AK1, BK1),
            MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
                .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
            ComputeDataType, // ComputeDataType for A
            ComputeDataType, // ComputeDataType for B
            AComputeDataType,
            BComputeDataType,
            AccDataType,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
@@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
            static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
            static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -73,7 +73,7 @@ template <typename ADataType,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename BComputeDataType = AComputeDataType_>
          typename BComputeDataType_ = AComputeDataType_>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
    static constexpr index_t NumDTensor = DsDataType::Size();
@@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
#if CK_WORKAROUND_DENORM_FIX
    using AComputeDataType =
        conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
    using BComputeDataType =
        conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
    using AComputeDataType = AComputeDataType_;
    using BComputeDataType = BComputeDataType_;
#endif

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
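// Editor's note: with distinct A/B compute types the shared-memory estimate
// above can no longer multiply a combined element count by one sizeof; the new
// formula sizes the two LDS regions separately. A hedged numeric example (the
// element counts are assumptions for illustration):
//
//   // a_block_space_size_aligned = 4096 elements of an 1-byte type -> 4096 B
//   // b_block_space_size_aligned = 4096 elements of a 2-byte type  -> 8192 B
//   // old: (4096 + 4096) * sizeof(ComputeDataType)  -- wrong once types differ
//   // new: 4096 * 1 + 4096 * 2 = 12288 B, then max() with the C-shuffle block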
@@ -31,7 +31,8 @@ namespace ck {
// D0, D1, ... and E have the same layout
template <typename ADataType,
          typename BDataType,
          typename ComputeType,
          typename AComputeType,
          typename BComputeType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
@@ -71,7 +72,9 @@ template <typename ADataType,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1>
          PipelineVersion PipelineVer,
          typename ALDSType,
          typename BLDSType>
struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
{
    static constexpr index_t NumDTensor = DsDataType::Size();
@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(ComputeType),
        return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
                             b_block_space_size_aligned * sizeof(BLDSType),
                         c_block_size * sizeof(CShuffleDataType));
    }

@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              index_t NumDTensor_,
              typename DsDataType_,
              bool Zeroing,
              typename AGridDesc_KBatch_AK0_M_AK1,
              typename BGridDesc_KBatch_BK0_N_BK1,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ADataType,
            ComputeType,
            ALDSType,
            decltype(a_grid_desc_kbatch_ak0_m_ak1),
            decltype(a_block_desc_kbatch_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BDataType,
            ComputeType,
            BLDSType,
            decltype(b_grid_desc_kbatch_bk0_n_bk1),
            decltype(b_block_desc_kbatch_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
                      MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
            ComputeType,
            ComputeType,
            ALDSType,
            BLDSType,
            AccDataType,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
@@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            MXdlPerWave,
            NXdlPerWave,
            KPack,
            LoopSched>();
            LoopSched,
            AComputeType,
            BComputeType>();

#if 1
        if(block_work_idx[I0] == 0)
        if constexpr(Zeroing)
        {
            const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
            const index_t numNThreads = NPerBlock / nThreadSize;
            const index_t numMThreads = BlockSize / numNThreads;
            const index_t mThreadSize = MPerBlock / numMThreads;

            const index_t m_tid = get_thread_local_1d_id() / numNThreads;
            const index_t n_tid = get_thread_local_1d_id() % numNThreads;

            auto c_thread_desc_mblock_mperblock_nblock_nperblock =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));

            StaticBuffer<AddressSpaceEnum::Vgpr,
                         EDataType,
                         c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
                         true>
                e_thread_zero_buf;

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                EDataType,
                EDataType,
                decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
                decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1, mThreadSize, 1, nThreadSize>,
                Sequence<0, 1, 2, 3>,
                3,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                InMemoryDataOperationEnum::Set,
                1,
                true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
                      make_multi_index(block_work_idx[I1],
                                       m_tid * mThreadSize,
                                       block_work_idx[I2],
                                       n_tid * nThreadSize),
                      ck::tensor_operation::element_wise::PassThrough{}};

            c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
                              make_tuple(I0, I0, I0, I0),
                              e_thread_zero_buf,
                              e_grid_desc_mblock_mperblock_nblock_nperblock,
                              e_grid_buf);

            __syncthreads();

            if(threadIdx.x == 0)
            if(block_work_idx[I0] == 0)
            {
                atomicAdd(barrier_count_finished, 1);
                const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
                const index_t numNThreads = NPerBlock / nThreadSize;
                const index_t numMThreads = BlockSize / numNThreads;
                const index_t mThreadSize = MPerBlock / numMThreads;

                const index_t m_tid = get_thread_local_1d_id() / numNThreads;
                const index_t n_tid = get_thread_local_1d_id() % numNThreads;

                auto c_thread_desc_mblock_mperblock_nblock_nperblock =
                    make_naive_tensor_descriptor_packed(
                        make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));

                StaticBuffer<AddressSpaceEnum::Vgpr,
                             EDataType,
                             c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
                             true>
                    e_thread_zero_buf;

                auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                    EDataType,
                    EDataType,
                    decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
                    decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    ck::tensor_operation::element_wise::PassThrough,
                    Sequence<1, mThreadSize, 1, nThreadSize>,
                    Sequence<0, 1, 2, 3>,
                    3,
                    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
                          make_multi_index(block_work_idx[I1],
                                           m_tid * mThreadSize,
                                           block_work_idx[I2],
                                           n_tid * nThreadSize),
                          ck::tensor_operation::element_wise::PassThrough{}};

                c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
                                  make_tuple(I0, I0, I0, I0),
                                  e_thread_zero_buf,
                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                  e_grid_buf);

                __builtin_amdgcn_s_barrier();

                if(threadIdx.x == 0)
                {
                    atomicAdd(barrier_count_finished, 1);
                }
            }
        }
#endif

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
            static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
            static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
@@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle

        // shuffle C and write out
        {
            if(threadIdx.x == 0)
            if constexpr(Zeroing)
            {
                while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
                if(threadIdx.x == 0)
                {
                    while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
                }
                __builtin_amdgcn_s_barrier();
            }

            __syncthreads();

            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");
@@ -951,18 +960,131 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                }
            });

            if(threadIdx.x == 0)
            if constexpr(Zeroing)
            {
                index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);

                if(k_id_finished_t == KBatch)
                if(threadIdx.x == 0)
                {
                    *barrier_count_finished = 0;
                    index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);

                    if(k_id_finished_t == KBatch)
                    {
                        *barrier_count_finished = 0;
                    }
                }
            }
        }
    }

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              GemmSpecialization GemmSpec,
              typename ALayout,
              typename BLayout,
              typename DsLayout,
              typename ELayout,
              typename Block2ETileMap>
    __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
                                          const void* __restrict__ p_b_grid_,
                                          DsGridPointer p_ds_grid,
                                          void* __restrict__ p_e_grid_,
                                          void* __restrict__ p_shared,
                                          uint32_t* barrier_count_finished,
                                          const AElementwiseOperation& a_element_op,
                                          const BElementwiseOperation& b_element_op,
                                          const CDEElementwiseOperation& cde_element_op,
                                          const index_t M,
                                          const index_t N,
                                          const index_t K,
                                          const index_t StrideA,
                                          const index_t StrideB,
                                          const std::array<index_t, NumDTensor> StrideDs,
                                          const index_t StrideE,
                                          const index_t KBatch,
                                          const Block2ETileMap& block_2_etile_map)
    {
        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);

        using DsGridDesc_M_N =
            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;

        DsGridDesc_M_N ds_grid_desc_m_n;

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;

            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
        });

        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

        // tensor descriptors for block/thread-wise copy
        const auto a_grid_desc_kbatch_ak0_m_ak1 =
            MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);

        const auto b_grid_desc_kbatch_bk0_n_bk1 =
            MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);

        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
            remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                DsGridDesc_M_N{}))>;

        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
        });

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);

        const auto block_work_idx =
            block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

        if(kbatch_id == KBatch - 1)
        {
            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
                p_a_grid,
                p_b_grid,
                p_ds_grid,
                p_e_grid,
                p_shared,
                barrier_count_finished,
                KBatch,
                a_element_op,
                b_element_op,
                cde_element_op,
                a_grid_desc_kbatch_ak0_m_ak1,
                b_grid_desc_kbatch_bk0_n_bk1,
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                block_2_etile_map);
        }
        else
        {
            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
                p_a_grid,
                p_b_grid,
                p_ds_grid,
                p_e_grid,
                p_shared,
                barrier_count_finished,
                KBatch,
                a_element_op,
                b_element_op,
                ck::tensor_operation::element_wise::PassThrough{},
                a_grid_desc_kbatch_ak0_m_ak1,
                b_grid_desc_kbatch_bk0_n_bk1,
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                block_2_etile_map);
        }
    }

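// Editor's note: in RunWithZeroing above, every block accumulates its K-slice
// into E, but only the block owning the last k-batch (kbatch_id == KBatch - 1)
// also loads the D tensors and applies cde_element_op; the other k-batches run
// with an empty Ds tuple and PassThrough, so the epilogue is applied exactly
// once per tile. A distilled sketch of the dispatch:
//
//   if(kbatch_id == KBatch - 1)
//       Run</*...,*/ NumDTensor, DsDataType, true>(/*...,*/ cde_element_op /*, ...*/);
//   else
//       Run</*...,*/ 0, Tuple<>, true>(/*...,*/ PassThrough{} /*, ...*/);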
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
@@ -976,7 +1098,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
uint32_t* barrier_count_finished,
|
||||
uint32_t*,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
@@ -1028,49 +1150,22 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
if(kbatch_id == KBatch - 1)
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
nullptr,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
};
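A minimal host-side sketch of the dispatch logic above (names here are illustrative, not the CK kernel itself): with a KBatch-way split along K, only the work-group that owns the last k-batch applies the D tensors and the real CDE elementwise operation, while every other k-batch accumulates partial sums through a PassThrough op; the final Run<..., false> pass then runs without the barrier counter.

#include <cstdio>

enum class CdeOp { PassThrough, Full };

// Mirrors the branch on kbatch_id above: the last k-batch fuses the D tensors
// and the full CDE elementwise op; the others only accumulate partial GEMM sums.
CdeOp pick_cde_op(int kbatch_id, int KBatch)
{
    return (kbatch_id == KBatch - 1) ? CdeOp::Full : CdeOp::PassThrough;
}

int main()
{
    const int KBatch = 4;
    for(int k = 0; k < KBatch; ++k)
        std::printf("kbatch %d -> %s\n",
                    k,
                    pick_cde_op(k, KBatch) == CdeOp::Full ? "Full" : "PassThrough");
    return 0;
}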
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -14,6 +14,10 @@
namespace ck {
namespace math {

#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif

// math functions for the host, some are implemented by calling C++ std functions

static inline __host__ float abs(float x) { return std::abs(x); };
@@ -111,6 +115,276 @@ inline __host__ double tanh<double>(double x)
    return std::tanh(x);
};

template <typename T>
inline __host__ T acos(T x)
{
    return ck::type_convert<T>(std::acosf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float acos<float>(float x)
{
    return std::acosf(x);
};

template <>
inline __host__ double acos<double>(double x)
{
    return std::acos(x);
};

template <typename T>
inline __host__ T neg(T x)
{
    return ck::type_convert<T>(-(ck::type_convert<float>(x)));
};

template <>
inline __host__ float neg<float>(float x)
{
    return -x;
};

template <>
inline __host__ double neg<double>(double x)
{
    return -x;
};

template <>
inline __host__ int32_t neg<int32_t>(int32_t x)
{
    return -x;
};

template <>
inline __host__ int8_t neg<int8_t>(int8_t x)
{
    return -x;
};

template <typename T>
inline __host__ T atan(T x)
{
    return ck::type_convert<T>(std::atanf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float atan<float>(float x)
{
    return std::atanf(x);
};

template <>
inline __host__ double atan<double>(double x)
{
    return std::atan(x);
};

template <typename T>
inline __host__ T sin(T x)
{
    return ck::type_convert<T>(std::sinf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float sin<float>(float x)
{
    return std::sinf(x);
};

template <>
inline __host__ double sin<double>(double x)
{
    return std::sin(x);
};

template <typename T>
inline __host__ T asin(T x)
{
    return ck::type_convert<T>(std::asinf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float asin<float>(float x)
{
    return std::asinf(x);
};

template <>
inline __host__ double asin<double>(double x)
{
    return std::asin(x);
};

template <typename T>
inline __host__ T asinh(T x)
{
    return ck::type_convert<T>(std::asinhf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float asinh<float>(float x)
{
    return std::asinhf(x);
};

template <>
inline __host__ double asinh<double>(double x)
{
    return std::asinh(x);
};

template <typename T>
inline __host__ T cos(T x)
{
    return ck::type_convert<T>(std::cosf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float cos<float>(float x)
{
    return std::cosf(x);
};

template <>
inline __host__ double cos<double>(double x)
{
    return std::cos(x);
};

template <typename T>
inline __host__ T acosh(T x)
{
    return ck::type_convert<T>(std::acoshf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float acosh<float>(float x)
{
    return std::acoshf(x);
};

template <>
inline __host__ double acosh<double>(double x)
{
    return std::acosh(x);
};

template <typename T>
inline __host__ T tan(T x)
{
    return ck::type_convert<T>(std::tanf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float tan<float>(float x)
{
    return std::tanf(x);
};

template <>
inline __host__ double tan<double>(double x)
{
    return std::tan(x);
};

template <typename T>
inline __host__ T atanh(T x)
{
    return ck::type_convert<T>(std::atanhf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float atanh<float>(float x)
{
    return std::atanhf(x);
};

template <>
inline __host__ double atanh<double>(double x)
{
    return std::atanh(x);
};

template <typename T>
inline __host__ T sinh(T x)
{
    return ck::type_convert<T>(std::sinhf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float sinh<float>(float x)
{
    return std::sinhf(x);
};

template <>
inline __host__ double sinh<double>(double x)
{
    return std::sinh(x);
};

template <typename T>
inline __host__ T ceil(T x)
{
    return ck::type_convert<T>(std::ceilf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float ceil<float>(float x)
{
    return std::ceilf(x);
};

template <>
inline __host__ double ceil<double>(double x)
{
    return std::ceil(x);
};

template <typename T>
inline __host__ T cosh(T x)
{
    return ck::type_convert<T>(std::coshf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float cosh<float>(float x)
{
    return std::coshf(x);
};

template <>
inline __host__ double cosh<double>(double x)
{
    return std::cosh(x);
};

template <typename T>
inline __host__ T floor(T x)
{
    return ck::type_convert<T>(std::floorf(ck::type_convert<float>(x)));
};

template <>
inline __host__ float floor<float>(float x)
{
    return std::floorf(x);
};

template <>
inline __host__ double floor<double>(double x)
{
    return std::floor(x);
};

template <typename T>
inline __host__ T rcp(T x)
{
    return ck::type_convert<T>(1.f / ck::type_convert<float>(x));
};
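The wrappers above all follow one pattern: the generic template converts the argument to float, evaluates the f-suffixed std function, and converts back, while the float/double specializations call the std function directly. A self-contained sketch of that pattern (MyHalf, to_float, and from_float are stand-ins for ck::half_t and ck::type_convert, not CK API):

#include <cmath>
#include <cstdio>

struct MyHalf { float value; };                        // toy half-like type (assumption)
static float  to_float(MyHalf h) { return h.value; }
static MyHalf from_float(float f) { return MyHalf{f}; }

// generic path: route the half-like type through float, as the templates above do
template <typename T>
T my_sin(T x) { return from_float(std::sin(to_float(x))); }

// full-precision specialization, analogous to sin<double> above
template <>
double my_sin<double>(double x) { return std::sin(x); }

int main()
{
    MyHalf h{0.5f};
    std::printf("%f %f\n", to_float(my_sin(h)), my_sin(0.5));
    return 0;
}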

template <typename T>
inline __host__ T exp(T x)
{
@@ -282,6 +556,286 @@ inline __device__ double tanh<double>(double x)
    return ::tanh(x);
};

template <typename T>
inline __device__ T acos(T x)
{
    return ck::type_convert<T>(::acosf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float acos<float>(float x)
{
    return ::acosf(x);
};

template <>
inline __device__ double acos<double>(double x)
{
    return ::acos(x);
};

template <typename T>
inline __device__ T neg(T x)
{
    return ck::type_convert<T>(-(ck::type_convert<float>(x)));
};

template <>
inline __device__ float neg<float>(float x)
{
    return -x;
};

template <>
inline __device__ double neg<double>(double x)
{
    return -x;
};

template <>
inline __device__ int32_t neg<int32_t>(int32_t x)
{
    return -x;
};

template <>
inline __device__ int8_t neg<int8_t>(int8_t x)
{
    return -x;
};

template <>
inline __device__ half_t neg<half_t>(half_t x)
{
    return __hneg(x);
};

template <typename T>
inline __device__ T atan(T x)
{
    return ck::type_convert<T>(::atanf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float atan<float>(float x)
{
    return ::atanf(x);
};

template <>
inline __device__ double atan<double>(double x)
{
    return ::atan(x);
};

template <typename T>
inline __device__ T sin(T x)
{
    return ck::type_convert<T>(::sinf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float sin<float>(float x)
{
    return ::sinf(x);
};

template <>
inline __device__ double sin<double>(double x)
{
    return ::sin(x);
};

template <>
inline __device__ half_t sin<half_t>(half_t x)
{
    return ::hsin(x);
};

template <typename T>
inline __device__ T asin(T x)
{
    return ck::type_convert<T>(::asinf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float asin<float>(float x)
{
    return ::asinf(x);
};

template <>
inline __device__ double asin<double>(double x)
{
    return ::asin(x);
};

template <typename T>
inline __device__ T asinh(T x)
{
    return ck::type_convert<T>(::asinhf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float asinh<float>(float x)
{
    return ::asinhf(x);
};

template <>
inline __device__ double asinh<double>(double x)
{
    return ::asinh(x);
};

template <typename T>
inline __device__ T acosh(T x)
{
    return ck::type_convert<T>(::acoshf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float acosh<float>(float x)
{
    return ::acoshf(x);
};

template <>
inline __device__ double acosh<double>(double x)
{
    return ::acosh(x);
};

template <typename T>
inline __device__ T tan(T x)
{
    return ck::type_convert<T>(::tanf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float tan<float>(float x)
{
    return ::tanf(x);
};

template <>
inline __device__ double tan<double>(double x)
{
    return ::tan(x);
};

template <typename T>
inline __device__ T atanh(T x)
{
    return ck::type_convert<T>(::atanhf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float atanh<float>(float x)
{
    return ::atanhf(x);
};

template <>
inline __device__ double atanh<double>(double x)
{
    return ::atanh(x);
};

template <typename T>
inline __device__ T sinh(T x)
{
    return ck::type_convert<T>(::sinhf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float sinh<float>(float x)
{
    return ::sinhf(x);
};

template <>
inline __device__ double sinh<double>(double x)
{
    return ::sinh(x);
};

template <typename T>
inline __device__ T ceil(T x)
{
    return ck::type_convert<T>(::ceilf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float ceil<float>(float x)
{
    return ::ceilf(x);
};

template <>
inline __device__ double ceil<double>(double x)
{
    return ::ceil(x);
};

template <>
inline __device__ half_t ceil<half_t>(half_t x)
{
    return ::hceil(x);
};

template <typename T>
inline __device__ T cosh(T x)
{
    return ck::type_convert<T>(::coshf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float cosh<float>(float x)
{
    return ::coshf(x);
};

template <>
inline __device__ double cosh<double>(double x)
{
    return ::cosh(x);
};

template <typename T>
inline __device__ T floor(T x)
{
    return ck::type_convert<T>(::floorf(ck::type_convert<float>(x)));
};

template <>
inline __device__ float floor<float>(float x)
{
    return ::floorf(x);
};

template <>
inline __device__ double floor<double>(double x)
{
    return ::floor(x);
};

template <>
inline __device__ half_t floor<half_t>(half_t x)
{
    return ::hfloor(x);
};

template <typename T>
inline __device__ T rcp(T x)
{
#if !CK_WORKAROUND_SWDEV_383542
    return __frcp_rn(x);
#else
    return __ocml_native_recip_f32(x);
#endif
};

template <typename T>
inline __device__ T exp(T x)
{

@@ -6,6 +6,7 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"

@@ -7,6 +7,20 @@

namespace ck_tile {

enum struct GenericAttentionMaskEnum
{
    NO_MASK = 0,

    // the enum values below can describe a causal or a sliding-window mask
    MASK_FROM_TOP_LEFT = 1,
    MASK_FROM_BOTTOM_RIGHT = 2,

    // this enum may not be used by xformers/FA, since it's hard to
    // specify a left/right window for the varlen case; kept here for
    // debugging purposes
    MASK_GENERIC,
};
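One plausible mapping from FA-style window arguments onto this enum, as an illustration (a hypothetical helper, not part of CK):

// Sketch only: pick an enum value from FA-style window sizes, assuming the
// convention that a negative size means "unbounded" on that side.
ck_tile::GenericAttentionMaskEnum pick_mask_type(int left, int right, bool is_top_left)
{
    if(left < 0 && right < 0)
        return ck_tile::GenericAttentionMaskEnum::NO_MASK; // no window at all
    return is_top_left ? ck_tile::GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT
                       : ck_tile::GenericAttentionMaskEnum::MASK_FROM_BOTTOM_RIGHT;
}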
// clang-format off
/* generic Attention Mask Coordinate
   use x(horizontal axis), y(vertical axis) to describe mask.
@@ -188,6 +202,129 @@ struct GenericAttentionMask
    index_t y_total, x_total;
};

// clang-format off
namespace impl {
template <bool IsMasking_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
}
// clang-format on

// this version has only 2 variations: masking and non-masking.
// This is friendlier to codegen (e.g. fewer kernels need to be generated)
// ... with the trade-off that it may execute more instructions in causal mode
template <bool IsMasking_ = true>
struct SimplifiedGenericAttentionMask
{
    static constexpr bool IsMasking = IsMasking_; // false will disable masking

    static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;

    CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
        : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
    {
    }

    CK_TILE_HOST_DEVICE
    SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
        : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
    {
    }
    template <typename MaskCoordinates>
    CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
        : y(mask_coord.at(number<0>{})),
          x(mask_coord.at(number<1>{})),
          y_total(mask_coord.at(number<2>{})),
          x_total(mask_coord.at(number<3>{}))
    {
    }

    // get the loop length along the X axis; returns indices [start, end), where end - start = length.
    // use this if you need to loop over the X axis tile by tile (like the k-seqlen loop)
    // TODO: x_end could still be negative, so end - start could be negative (needs a check)
    template <index_t YTile, index_t XTile>
    CK_TILE_HOST_DEVICE constexpr auto
    GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
    {
        if constexpr(!IsMasking)
        {
            return ck_tile::make_tuple(0, x_total);
        }
        else
        {
            // get the tile start/end range, assuming we loop along the X axis tile by tile
            index_t x_start = [&]() {
                index_t tmp = max(-y + i_y + 1, 0);
                return (tmp / XTile) * XTile; // round to tile aligned
            }();

            // TODO: end could be negative; we skip clamping here and let the caller check,
            // ... in which case end - start is negative
            index_t x_end = [&]() {
                index_t tmp = min(i_y + YTile - 1 + x, x_total);
                return ((tmp + XTile - 1) / XTile) * XTile;
            }();

            return ck_tile::make_tuple(x_start, x_end);
        }
    }

    // per-pixel out-of-bound check; if true, the value needs to be masked (e.g. with -INF)
    CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
    {
        if constexpr(!IsMasking)
        {
            // the only case that needs the following compare is under kPadSeqLenK
            // ... for the non-masking kernel.
            return i_x >= x_total;
        }
        else
        {
            // no need to do min/max here, since i_x will never be < 0 or >= x_total
            index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
            index_t x_end   = i_y + x;      // this could be larger than x_total, but it's fine

            return i_x < x_start || i_x >= x_end;
        }
    }

    // if the current tile is at the edge, a per-pixel mask check is needed;
    // otherwise there is no need to check per-pixel.
    // Attention! assumes the index passed to this function is within the range of GetTileRangeAlongX();
    // can be used as a fast path to decide whether to do the per-pixel check
    template <index_t TileHeight, index_t TileWidth>
    CK_TILE_HOST_DEVICE constexpr auto
    IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
    {
        if constexpr(!IsMasking)
        {
            // the only case that needs the following compare is under kPadSeqLenK
            // ... for the non-masking kernel.
            // return (i_x < x_total) && ((i_x + TileWidth) > x_total);

            // TODO: no need to check the begin
            return (i_x + TileWidth) > x_total;
        }
        else
        {
            // check whether the top-right corner exceeds the x window or the bottom-left corner exceeds the y window
            index_t i_x_end = i_x + TileWidth;
            index_t i_y_end = i_y + TileHeight;
            // index_t x_end = min(i_y + x, x_total);

            bool top_right_edge   = i_x_end > min(i_y + x, x_total); // consider right pad
            bool bottom_left_edge = i_y_end > (i_x + y);
            // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now

            return top_right_edge || bottom_left_edge;
        }
    }

    private:
    index_t y, x;
    index_t y_total, x_total;
};
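A usage sketch for the struct above (tile sizes are illustrative and the code assumes the ck_tile headers): loop only over the K-sequence tiles the mask can touch, and fall back to per-pixel checks on edge tiles.

using Mask = ck_tile::SimplifiedGenericAttentionMask<true>;

void scan_tiles(const Mask& mask, ck_tile::index_t i_y)
{
    constexpr ck_tile::index_t YTile = 64, XTile = 64;
    // [x_start, x_end) is the only range this row of tiles can touch
    auto range   = mask.GetTileRangeAlongX(i_y, ck_tile::number<YTile>{}, ck_tile::number<XTile>{});
    auto x_start = range.at(ck_tile::number<0>{});
    auto x_end   = range.at(ck_tile::number<1>{});
    for(ck_tile::index_t i_x = x_start; i_x < x_end; i_x += XTile)
    {
        if(mask.IsEdgeTile(i_y, i_x, ck_tile::number<YTile>{}, ck_tile::number<XTile>{}))
        {
            // edge tile: consult mask.IsOutOfBound(y, x) per element
        }
        // interior tile: no per-pixel masking needed
    }
}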
// TODO: prefer to use this function in host code
// converts from the FA-style left/right window to our generic coordinates
// if left_size < 0 && right_size == 0, it is the normal causal mask
@@ -199,29 +336,32 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
                                                       index_t x_total,
                                                       bool is_top_left = true)
{
    index_t x = 0, y = 0;
    // TODO: below should all use sgpr arithmetic
    index_t left_size_tmp  = is_top_left ? y_total - 1 : x_total - 1;
    index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;

    if(is_top_left)
    {
        if(left_size < 0)
            left_size = y_total - 1;
        if(right_size < 0)
            right_size = x_total - 1;
    left_size  = left_size < 0 ? left_size_tmp : left_size;
    right_size = right_size < 0 ? right_size_tmp : right_size;

        x = 1 + right_size;
        y = left_size + 1;
    }
    else
    {
        if(left_size < 0)
            left_size = x_total - 1;
        if(right_size < 0)
            right_size = y_total - 1;
    index_t x_tmp = is_top_left ? 0 : x_total - y_total;
    index_t y_tmp = is_top_left ? 0 : y_total - x_total;

        x = x_total - y_total + 1 + right_size;
        y = y_total - x_total + 1 + left_size;
    }
    index_t x = 1 + right_size + x_tmp;
    index_t y = 1 + left_size + y_tmp;

    return ck_tile::make_tuple(y, x, y_total, x_total);
}

template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
                                           index_t right_size,
                                           index_t y_total,
                                           index_t x_total,
                                           bool is_top_left = true)
{
    auto r = make_generic_attention_mask_coordinates_from_lr_window(
        left_size, right_size, y_total, x_total, is_top_left);
    return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
}
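For example, a standard causal mask could be built like this (shapes are illustrative; left_size < 0 meaning "unbounded to the left" follows the comment above):

using Mask = ck_tile::SimplifiedGenericAttentionMask<true>;
// left unbounded, right window 0, 128x128 score matrix, top-left aligned
auto causal = ck_tile::make_generic_attention_mask_from_lr_window<Mask>(
    /*left_size=*/-1, /*right_size=*/0,
    /*y_total=*/128, /*x_total=*/128,
    /*is_top_left=*/true);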
} // namespace ck_tile

@@ -157,7 +157,9 @@ struct FmhaFwdKernel

    struct FmhaFwdMaskKargs
    {
        ck_tile::index_t mask_y, mask_x;
        // ck_tile::index_t window_size_left, window_size_right;
        ck_tile::index_t window_size_left, window_size_right;
        ck_tile::GenericAttentionMaskEnum mask_type;
    };

    struct FmhaFwdCommonLSEKargs
@@ -227,8 +229,9 @@ struct FmhaFwdKernel
        ck_tile::index_t batch_stride_bias,
        ck_tile::index_t batch_stride_lse,
        ck_tile::index_t batch_stride_o,
        ck_tile::index_t mask_y,
        ck_tile::index_t mask_x,
        ck_tile::index_t window_size_left,
        ck_tile::index_t window_size_right,
        ck_tile::index_t mask_type,
        // QElementFunction q_element_func,
        // KElementFunction k_element_func,
        // VElementFunction v_element_func,
@@ -285,8 +288,9 @@ struct FmhaFwdKernel
        }
        if constexpr(kHasMask)
        {
            kargs.mask_y = mask_y;
            kargs.mask_x = mask_x;
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kStoreLSE)
        {
@@ -324,8 +328,9 @@ struct FmhaFwdKernel
        ck_tile::index_t nhead_stride_bias,
        ck_tile::index_t nhead_stride_lse,
        ck_tile::index_t nhead_stride_o,
        ck_tile::index_t mask_y,
        ck_tile::index_t mask_x,
        ck_tile::index_t window_size_left,
        ck_tile::index_t window_size_right,
        ck_tile::index_t mask_type,
        // QElementFunction q_element_func,
        // KElementFunction k_element_func,
        // VElementFunction v_element_func,
@@ -380,8 +385,9 @@ struct FmhaFwdKernel
        }
        if constexpr(kHasMask)
        {
            kargs.mask_y = mask_y;
            kargs.mask_x = mask_x;
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kStoreLSE)
        {
@@ -665,7 +671,12 @@ struct FmhaFwdKernel

        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k};
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

// This enum is used for codegen pattern matching
enum class BlockFmhaPipelineEnum
{
    QRKSVS = 0,
    QRKSVS_ASYNC,
    QRKSVS_FP8,
    QSKSVS,
};

} // namespace ck_tile
@@ -0,0 +1,110 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <index_t NumATensors, typename ADataType, typename BDataType, typename ElementOp>
struct ReferenceElementwise : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
                 Tensor<BDataType>& b_tensor,
                 ElementOp element_op)
            : a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op}
        {
        }

        const std::array<Tensor<ADataType>, NumATensors>& a_tensors_;
        Tensor<BDataType>& b_tensor_;
        ElementOp element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceElementwise::Argument;

        float Run(const Argument& arg)
        {
            if constexpr(NumATensors == 1)
            {
                arg.b_tensor_.ForEach([&](auto& self, auto idx) {
                    arg.element_op_(self(idx), arg.a_tensors_[0](idx));
                });
            }
            else if constexpr(NumATensors == 2)
            {
                arg.b_tensor_.ForEach([&](auto& self, auto idx) {
                    arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx));
                });
            }
            else if constexpr(NumATensors == 3)
            {
                arg.b_tensor_.ForEach([&](auto& self, auto idx) {
                    arg.element_op_(self(idx),
                                    arg.a_tensors_[0](idx),
                                    arg.a_tensors_[1](idx),
                                    arg.a_tensors_[2](idx));
                });
            }
            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
                             Tensor<BDataType>& b_tensor,
                             ElementOp element_op)
    {
        return Argument{a_tensors, b_tensor, element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceElementwise"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
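A host-side usage sketch for the reference op above (the Add functor, the Tensor shapes, and the exact namespace qualification are illustrative and assume CK's host_tensor utilities):

using Ref = ck::tensor_operation::host::ReferenceElementwise<
    2, float, float, ck::tensor_operation::element_wise::Add>;

void run_reference(const std::array<Tensor<float>, 2>& inputs, Tensor<float>& out)
{
    auto arg     = Ref::MakeArgument(inputs, out, ck::tensor_operation::element_wise::Add{});
    auto invoker = Ref::MakeInvoker();
    invoker.Run(arg); // fills `out` element by element on the host
}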
@@ -0,0 +1,175 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

// assumption: every D matrix has the same layout and the same datatype
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemmMultipleD : public device::BaseOperator
{
    using DDataType = remove_cvref_t<tuple_element_t<0, DsDataType>>;
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n,
                 Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              ds_m_n_{ds_m_n},
              c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n_;
        Tensor<CDataType>& c_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmMultipleD::Argument;

        float Run(const Argument& arg)
        {
            auto f_mk_kn_mn = [&](auto m, auto n) {
                const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                AccDataType v_acc = 0;
                ComputeTypeA v_a  = 0;
                ComputeTypeB v_b  = 0;

                for(int k = 0; k < K; ++k)
                {
                    // use PassThrough instead of ConvertBF16RTN for reference calculation
                    if constexpr(is_same_v<AElementwiseOperation,
                                           ck::tensor_operation::element_wise::ConvertBF16RTN>)
                    {
                        ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
                    }
                    else
                    {
                        arg.a_element_op_(v_a, arg.a_m_k_(m, k));
                    }
                    // same for the B matrix
                    if constexpr(is_same_v<BElementwiseOperation,
                                           ck::tensor_operation::element_wise::ConvertBF16RTN>)
                    {
                        ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
                    }
                    else
                    {
                        arg.b_element_op_(v_b, arg.b_k_n_(k, n));
                    }

                    v_acc +=
                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                }

                CDataType v_c = 0;

                if constexpr(DsDataType::Size() == 0)
                {
                    arg.cde_element_op_(v_c, v_acc);
                }
                else if constexpr(DsDataType::Size() == 1)
                {
                    arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n));
                }
                else if constexpr(DsDataType::Size() == 2)
                {
                    arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n));
                }

                arg.c_m_n_(m, n) = v_c;
            };

            make_ParallelTensorFunctor(
                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{a_m_k, b_k_n, ds_m_n, c_m_n, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmMultipleD"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
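And a matching sketch for validating a fused GEMM + bias-add against the reference above (types, shapes, and the Add functor are illustrative; assumes CK host utilities):

using F16     = ck::half_t;
using Pass    = ck::tensor_operation::element_wise::PassThrough;
using Add     = ck::tensor_operation::element_wise::Add;
using RefGemm = ck::tensor_operation::host::
    ReferenceGemmMultipleD<F16, F16, ck::Tuple<F16>, F16, float, Pass, Pass, Add>;

void check(const Tensor<F16>& a_m_k,
           const Tensor<F16>& b_k_n,
           const std::array<Tensor<F16>, 1>& ds_m_n,
           Tensor<F16>& c_m_n)
{
    auto arg     = RefGemm::MakeArgument(a_m_k, b_k_n, ds_m_n, c_m_n, Pass{}, Pass{}, Add{});
    auto invoker = RefGemm::MakeInvoker();
    invoker.Run(arg); // c = Add(acc, d0) computed on the host, multi-threaded
}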
@@ -12,398 +12,21 @@

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

#ifdef DL_KERNELS
#include "gemm_dl.inc"
#endif
#ifdef CK_USE_WMMA
#include "gemm_wmma.inc"
#endif
#ifdef CK_USE_XDL
#include "gemm_xdl.inc"
#endif

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS)
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS)
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS)
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif

void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

template <typename ALayout,
          typename BLayout,
          typename CLayout,
@@ -435,16 +58,137 @@ struct DeviceOperationInstanceFactory<
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef DL_KERNELS
        if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
                     is_same_v<CDataType, float>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                /// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
                add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
            }
        }
#ifdef CK_ENABLE_FP16
        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                          is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
                add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
                add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
            }
        }
#endif
#ifdef CK_ENABLE_INT8
        else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
                          is_same_v<CDataType, int8_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
                add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
                add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
                add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
            }
        }
#endif
#endif // DL_KERNELS

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            }
        }
#endif
#endif

#ifdef CK_USE_XDL
        if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
                     is_same_v<CDataType, float>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
                    op_ptrs);
@@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory<
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                /// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
                add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
#endif
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
                add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
                    op_ptrs);
@@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory<
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                /// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
                add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
#endif
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
                add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
                    op_ptrs);
@@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory<
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                /// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
                add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
#endif
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
                add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
                    op_ptrs);
@@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory<
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                /// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#ifdef DL_KERNELS
                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
#endif
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                /// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#ifdef DL_KERNELS
                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
                add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
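
// --- Usage sketch (illustration only, not part of this diff): how a client can
// consume the factory populated above. The problem sizes, strides, and pointers
// are placeholders; the factory/operator calls follow CK's client-API pattern,
// so treat exact signatures as an assumption rather than a guarantee.
#include <memory>
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

void run_first_supported_gemm(const void* p_a, const void* p_b, void* p_c,
                              ck::index_t M, ck::index_t N, ck::index_t K,
                              ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using namespace ck::tensor_operation::device;
    using DeviceOp =
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
    // GetInstances() returns every instance whose compile-time guards
    // (CK_USE_XDL / CK_USE_WMMA / CK_ENABLE_*) were active when the library was built.
    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    for(auto& op : op_ptrs)
    {
        auto arg = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{});
        if(op->IsSupportedArgument(arg.get()))
        {
            op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
            return; // run the first instance that supports this problem
        }
    }
}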

@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Row,
@@ -69,7 +69,7 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
                                                    PassThrough,
                                                    Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Row,
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
        {
@@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
        }
    }
#endif
#ifdef CK_ENABLE_INT8
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
        if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
                     is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
        {

@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if defined(CK_ENABLE_FP16)
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#if defined(CK_ENABLE_FP32)
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#if defined(CK_ENABLE_INT8)
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
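
// Naming note plus a small sketch (illustration only): the layout tokens in the
// names above map onto the DeviceGemm template arguments -- "mk"/"km" select
// Row/Col for A, "kn"/"nk" select Row/Col for B, and "mn" is the row-major C --
// so _mk_nk_mn_ pairs with DeviceGemm<Row, Col, Row, ...>. Listing what a pair
// of the declarations registers:
//
//     std::vector<std::unique_ptr<
//         DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>
//         instances;
//     add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(instances);
//     add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(instances);
//     for(const auto& inst : instances)
//         std::cout << inst->GetTypeString() << '\n'; // GetTypeString() from ck's BaseOperator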

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// (Assumed include set, mirroring the sibling DL instance header above, so that
// std::vector/std::unique_ptr and the DeviceGemm/layout/elementwise names used
// below resolve; the file as committed relied on its includer to provide them.)
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
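
// Sketch (hypothetical helper, illustration only): because the WMMA and XDL
// instance families live behind independent guards, an aggregator can register
// whichever were compiled in:
//
//     template <typename Vec>
//     void add_all_f16_mk_nk_mn_instances(Vec& instances)
//     {
//     #ifdef CK_USE_WMMA
//         add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(instances);
//     #endif
//     #ifdef CK_USE_XDL
//         add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(instances);
//     #endif
//     }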

@@ -0,0 +1,238 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// (Assumed include set, mirroring the DL instance header earlier in this diff,
// so the std:: and DeviceGemm/layout/elementwise names used below resolve; the
// file as committed relied on its includer to provide them.)
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_INT8
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
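
// The f8 GEMM above is split per pipeline version (v1/v2), scheduling
// (default/interwave) and padding (default/padded); a caller wanting the whole
// mk_kn_mn family must add each set. Sketch (helper name hypothetical):
//
//     template <typename Vec>
//     void add_all_f8_mk_kn_mn_instances(Vec& v)
//     {
//         add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(v);
//         add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(v);
//         add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(v);
//         add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(v);
//         add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(v);
//         add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(v);
//     }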

@@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple<
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
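
// Sketch (illustration only; the layouts and spec below are placeholder choices):
// instance tuples like device_grouped_conv_fwd_xdl_f8_bf8_instances are typically
// materialized into runtime instances with add_device_operation_instances, e.g.
//
//     void add_device_grouped_conv2d_fwd_xdl_f8_bf8_nhwgc_instances(auto& instances)
//     {
//         add_device_operation_instances(
//             instances,
//             device_grouped_conv_fwd_xdl_f8_bf8_instances<2, NHWGC, GKYXC, Empty_Tuple,
//                                                          NHWGK, ConvFwdDefault>{});
//     }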

@@ -10,439 +10,18 @@

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

#ifdef CK_USE_XDL
#include "grouped_convolution_backward_data_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_backward_data_wmma.inc"
#endif

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, int8_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, BF8, F8>>>& instances);
#endif
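
// Usage sketch (illustration only): querying the backward-data factory defined
// below for a 3D NDHWGC/GKZYXC/NDHWGK f16 problem; which instances come back is
// decided by the CK_USE_XDL / CK_USE_WMMA guards the library was built with.
//
//     using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
//         3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, F16, F16, Empty_Tuple, F16,
//         PassThrough, PassThrough, PassThrough>;
//     const auto op_ptrs = ck::tensor_operation::device::instance::
//         DeviceOperationInstanceFactory<DeviceOp>::GetInstances();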

template <ck::index_t NumDimSpatial,
          typename OutLayout,
          typename WeiLayout,
@@ -488,9 +67,10 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_USE_XDL
        if constexpr(NumDimSpatial == 2)
        {

            if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
                         is_same_v<OutLayout, GNHWK>)
            {
@@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory<
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_FP32
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                                  is_same_v<ComputeTypeB, F32>)
                if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                             is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                             is_same_v<ComputeTypeB, F32>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
                }
#endif
#ifdef CK_ENABLE_BF16
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                                  is_same_v<ComputeTypeB, BF16>)
                if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                             is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                             is_same_v<ComputeTypeB, BF16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_INT8
                else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                                  is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
                                  is_same_v<ComputeTypeB, int8_t>)
                {
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
            }
            else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
                              is_same_v<OutLayout, NHWGK>)
            if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
                         is_same_v<OutLayout, NHWGK>)
            {
#ifdef CK_ENABLE_FP16
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
@@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory<
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_FP32
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                                  is_same_v<ComputeTypeB, F32>)
                if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                             is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                             is_same_v<ComputeTypeB, F32>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
                }
#endif
#ifdef CK_ENABLE_BF16
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                                  is_same_v<ComputeTypeB, BF16>)
                if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                             is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                             is_same_v<ComputeTypeB, BF16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_INT8
                else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                                  is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
                                  is_same_v<ComputeTypeB, int8_t>)
                {
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
            }
        }
        else if constexpr(NumDimSpatial == 3)
        if constexpr(NumDimSpatial == 3)
        {

            if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
                         is_same_v<OutLayout, GNDHWK>)
            {
@@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory<
                {
                    add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_FP32
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                                  is_same_v<ComputeTypeB, F32>)
                if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                             is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                             is_same_v<ComputeTypeB, F32>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_BF16
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                                  is_same_v<ComputeTypeB, BF16>)
                if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                             is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                             is_same_v<ComputeTypeB, BF16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
                        op_ptrs);
                }
#endif
            }
            if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
                         is_same_v<OutLayout, NDHWGK>)
            {
#ifdef CK_ENABLE_FP16
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                        op_ptrs);
                }
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
                             is_same_v<ComputeTypeB, f8_t>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_FP32
                if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                             is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                             is_same_v<ComputeTypeB, F32>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_BF16
                if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                             is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                             is_same_v<ComputeTypeB, BF16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
                        op_ptrs);
                }
#endif
            }
        }
#endif

#ifdef CK_USE_WMMA
        if constexpr(NumDimSpatial == 2)
        {
            if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
                         is_same_v<OutLayout, GNHWK>)
            {
#ifdef CK_ENABLE_FP16
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_INT8
                else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                                  is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
                                  is_same_v<ComputeTypeB, int8_t>)
                if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                             is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
                             is_same_v<ComputeTypeB, int8_t>)
                {
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
            }
            if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
                         is_same_v<OutLayout, NHWGK>)
            {
#ifdef CK_ENABLE_FP16
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_INT8
                if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                             is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
                             is_same_v<ComputeTypeB, int8_t>)
                {
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
                    add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
            }
        }
        if constexpr(NumDimSpatial == 3)
        {
            if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
                         is_same_v<OutLayout, GNDHWK>)
            {
#ifdef CK_ENABLE_FP16
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#ifdef CK_ENABLE_INT8
                if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                             is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
                             is_same_v<ComputeTypeB, int8_t>)
                {
                    add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
                        op_ptrs);
@@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory<
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                        op_ptrs);
                    add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
                        op_ptrs);
                }
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
                else if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                                  is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
                                  is_same_v<ComputeTypeB, f8_t>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
op_ptrs);
|
||||
@@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
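For orientation, the sketch below shows how client code would typically consume the factory refactored above: instantiate `DeviceOperationInstanceFactory` with the `DeviceGroupedConvBwdDataMultipleD` signature it dispatches on, then call `GetInstances()`. This is a minimal usage sketch and not part of the commit; the include path, the `ck::half_t` alias, and the layout spellings are assumptions based on CK's public client headers.

// Usage sketch (assumed include path; not part of this diff).
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"

int main()
{
    namespace dev = ck::tensor_operation::device;
    namespace ew  = ck::tensor_operation::element_wise;

    // Mirror the F16 NDHWGC/GKZYXC/NDHWGK case dispatched in GetInstances() above.
    using DeviceOp = dev::DeviceGroupedConvBwdDataMultipleD<3,
                                                            ck::tensor_layout::convolution::NDHWGK,
                                                            ck::tensor_layout::convolution::GKZYXC,
                                                            ck::Tuple<>,
                                                            ck::tensor_layout::convolution::NDHWGC,
                                                            ck::half_t,
                                                            ck::half_t,
                                                            ck::Tuple<>,
                                                            ck::half_t,
                                                            ew::PassThrough,
                                                            ew::PassThrough,
                                                            ew::PassThrough>;

    // GetInstances() runs the dispatch above and returns one unique_ptr per kernel variant.
    const auto op_ptrs = dev::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
    return 0;
}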

@@ -0,0 +1,243 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  int8_t,
                                                                  int8_t,
                                                                  Empty_Tuple,
                                                                  int8_t,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
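The declarations above are the registration hooks that the factory calls; each implementation appends its compiled kernel variants to the caller's vector. A minimal sketch of invoking one hook directly follows; it assumes the same aliases (F16, GNHWK, Empty_Tuple, PassThrough, ...) that this header already relies on are in scope, and it is illustrative rather than a documented entry point.

// Sketch: populate an instance list directly, bypassing the factory.
// Assumes the aliases used by the declarations above are visible.
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                              GNHWK,
                                                              GKYXC,
                                                              Empty_Tuple,
                                                              GNHWC,
                                                              F16,
                                                              F16,
                                                              Empty_Tuple,
                                                              F16,
                                                              PassThrough,
                                                              PassThrough,
                                                              PassThrough>>> instances;
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(instances);
// instances.size() now equals the number of compiled WMMA f16 GNHWC variants.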

@@ -0,0 +1,216 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F32,
                                                                  F32,
                                                                  Empty_Tuple,
                                                                  F32,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F32,
                                                                  F32,
                                                                  Empty_Tuple,
                                                                  F32,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif

// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  F32,
                                                                  F32,
                                                                  Empty_Tuple,
                                                                  F32,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  GNDHWK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  GNDHWC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  F32,
                                                                  F32,
                                                                  Empty_Tuple,
                                                                  F32,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                  NDHWGK,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  BF8,
                                                                  F8>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
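Note the final declaration above: unlike the homogeneous-precision hooks, it passes two extra template arguments (BF8, F8) that pin ComputeTypeA/ComputeTypeB separately from the F16 tensor types. A short sketch of the distinction, assuming the primary template defaults the compute types to the input data type (an assumption about the template, not stated in this diff):

// Sketch: same tensor types, different compute element types.
// The trailing BF8/F8 arguments select the mixed-precision kernel pool;
// omitting them (the assumed default) selects the plain F16 pool.
using PlainF16Op = DeviceGroupedConvBwdDataMultipleD<3,
                                                     NDHWGK, GKZYXC, Empty_Tuple, NDHWGC,
                                                     F16, F16, Empty_Tuple, F16,
                                                     PassThrough, PassThrough, PassThrough>;

using MixedBf8F8Op = DeviceGroupedConvBwdDataMultipleD<3,
                                                       NDHWGK, GKZYXC, Empty_Tuple, NDHWGC,
                                                       F16, F16, Empty_Tuple, F16,
                                                       PassThrough, PassThrough, PassThrough,
                                                       BF8,  // ComputeTypeA
                                                       F8>;  // ComputeTypeB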