Merge branch 'ck_tile/refactor' into ck_tile/elementwise

rocking
2024-04-09 02:37:42 +08:00
committed by GitHub
216 changed files with 9200 additions and 4325 deletions

View File

@@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-if(GPU_TARGETS MATCHES "gfx11")
-add_custom_target(example_gemm_wmma)
-add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
-add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
-endif()
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
@@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
-# FIXME: re-enable this example as test when SWDEV-335738 is fixed
-add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
+add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
@@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
+add_custom_target(example_gemm_wmma)
+add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
+add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)

View File

@@ -1,20 +1,3 @@
-list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
-add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
-add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
-endif()
-if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
-set(target 1)
-endif()
-endforeach()
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
-add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
+add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
+add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)

View File

@@ -1,8 +1 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)

View File

@@ -1,29 +1,20 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_custom_target(example_gemm_add_add_fastgelu_xdl)
-add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
-add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
+add_custom_target(example_gemm_add_add_fastgelu_xdl)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
-add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
-add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
-add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
-add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
-if(USE_BITINT_EXTENSION_INT4)
-add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
-add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
-endif(USE_BITINT_EXTENSION_INT4)
-add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
-add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
-set(target 1)
-endif()
-endforeach()
+set(gpu_list "")
+if(USE_BITINT_EXTENSION_INT4)
+add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
+add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
+endif(USE_BITINT_EXTENSION_INT4)
+list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
+set(target 0)

View File

@@ -1,19 +1,11 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
-add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
-add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
-add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
-add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
-add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
-# FIXME: re-enable this example as test when SWDEV-335738 is fixed
-add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
+add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
+add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
+add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
+add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::bf8_t;
using AccDataType = float;
using CShuffleDataType = ck::f8_t;
using OutDataType = ck::f8_t;
using AComputeType = ck::f8_t;
using BComputeType = ck::bf8_t;
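// The A and B operands are cast to these compute types for the matrix math,
// so this example exercises mixed 8-bit compute: f8 activations with bf8 weights.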
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeType,
BComputeType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
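For intuition, the mixed-precision instance above boils down to "cast each operand to its compute type, multiply, accumulate in float". A minimal host-side sketch of that semantic in plain C++ follows; the quantize_f8/quantize_bf8 helpers and contract_mixed_8bit are illustrative stand-ins, not CK's f8/bf8 converters or API:

#include <cmath>
#include <cstdio>

// Crude stand-ins for 8-bit float rounding: snap values to a coarse grid.
// Real e4m3/e5m2 conversion is more involved; only the cast-before-multiply,
// accumulate-in-float structure of the example is mimicked here.
static float quantize_f8(float x) { return std::ldexp(std::round(std::ldexp(x, 3)), -3); }
static float quantize_bf8(float x) { return std::ldexp(std::round(std::ldexp(x, 2)), -2); }

// out[m][n] = sum_k cast_f8(a[m][k]) * cast_bf8(b[k][n]), accumulated in float,
// mirroring InDataType/AComputeType = f8 and WeiDataType/BComputeType = bf8 above.
void contract_mixed_8bit(const float* a, const float* b, float* out, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f; // AccDataType = float
            for(int k = 0; k < K; ++k)
                acc += quantize_f8(a[m * K + k]) * quantize_bf8(b[k * N + n]);
            out[m * N + n] = quantize_f8(acc); // OutDataType = f8
        }
}

int main()
{
    const float a[4] = {0.5f, 1.25f, -2.0f, 3.0f};
    const float b[4] = {1.0f, 0.25f, -1.5f, 2.0f};
    float c[4];
    contract_mixed_8bit(a, b, c, 2, 2, 2);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
    return 0;
}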

View File

@@ -1,25 +1,17 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)

View File

@@ -1,12 +1,3 @@
# dlops
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
# xdlops
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)

View File

@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)

View File

@@ -0,0 +1,394 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <ck/utility/data_type.hpp>
#include <ck/utility/tuple.hpp>
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType, DDataType>;
using EDataType = F32;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout, DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;
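// AddAdd epilogue: the GEMM result is combined with the two D tensors element-wise, E = C + D0 + D1.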
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr int NumDMatrices = 2;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
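// Block tile from the row above: 256 threads per workgroup computing a 64 x 128 x 32
// (MPerBlock x NPerBlock x KPerBlock) tile, built from 32x32 XDL ops, 1 x 2 XDLs per wave.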
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<std::vector<ck::index_t>> stride_Ds;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
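// k_batch is the split-K factor: K is partitioned k_batch ways and the partial
// products are reduced in the kernel's second stage.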
int k_batch = 128;
bool time_kernel = true;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
std::vector<std::array<const void*, NumDMatrices>> p_Ds = {};
gemm_descs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
p_Ds.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<std::array<Tensor<DDataType>, NumDMatrices>> d_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_result_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_result_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d_tensors_device.resize(group_count); // resize, not reserve: the inner vectors are indexed directly below
c_tensors_device.reserve(group_count);
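// Per group: 2 * M * N * K FLOPs; num_btype tallies bytes moved for A, B, both Ds and E.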
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDMatrices> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens);
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices +
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
}
}
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
problem_size.stride_Ds[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
gemm.SetKBatchSize(argument, config.k_batch);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
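// One non-timed run (time_kernel = false) so results are available for verification.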
invoker.Run(argument, StreamConfig{nullptr, false, 1});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
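// ave_time is in ms, so flop / 1e9 / ms = TFLOPS and bytes / 1e6 / ms = GB/s.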
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = argument.gemm_kernel_args_[i].karg_;
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{}));
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
c_device_result_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
d_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
cde_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
return pass;
}
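// Parse a comma-separated argument such as "256,256,128" into one integer per group.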
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc < 11)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({});
for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
}
}
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg10: k_batch (> 0)\n"
<< "... setting default values." << std::endl;
}
else
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[10]);
problem_size.Ms = argToIntArray(argv[4]);
problem_size.Ns = argToIntArray(argv[5]);
problem_size.Ks = argToIntArray(argv[6]);
problem_size.stride_As = argToIntArray(argv[7]);
problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
}
problem_size.group_count = problem_size.Ms.size();
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -36,7 +36,7 @@ using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
-using EDataType = F32;
+using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++)
{
-problem_size.Ms.push_back(256 + 256 * i);
-problem_size.Ns.push_back(256);
-problem_size.Ks.push_back(128);
+problem_size.Ms.push_back(128 + rand() % 128);
+problem_size.Ns.push_back(1024);
+problem_size.Ks.push_back(1024);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);

View File

@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F8;
using AccDataType = F32;
-using CShuffleDataType = F32;
+using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
struct ProblemSize final

View File

@@ -1,48 +1,41 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_custom_target(example_gemm_reduce_xdl)
-add_custom_target(example_gemm_reduce_xdl_max)
-add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
-add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
+add_custom_target(example_gemm_reduce_xdl)
+add_custom_target(example_gemm_reduce_xdl_max)
+add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
+add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
-add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
+add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
-add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
-add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
+add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
+add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
-add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
+add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
-add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
+add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
-add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
+add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
-add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
+add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
-add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
+add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
-add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
+add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
-add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
+add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
-add_example_dependencies(example_gemm_reduce_xdl
-example_gemm_reduce_xdl_mean_meansquare
-example_gemm_reduce_xdl_max
-example_gemm_add_add_mean_meansquare_xdl)
+add_example_dependencies(example_gemm_reduce_xdl
+example_gemm_reduce_xdl_mean_meansquare
+example_gemm_reduce_xdl_max
+example_gemm_add_add_mean_meansquare_xdl)
-if(USE_BITINT_EXTENSION_INT4)
-add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
-add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
-endif()
-set(target 1)
-endif()
-endforeach()
+if(USE_BITINT_EXTENSION_INT4)
+add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
+add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
+endif()

View File

@@ -1,14 +1,7 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
-if(result EQUAL 0)
-target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
-endif()
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
+if(result EQUAL 0)
+target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
+endif()
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
if(result EQUAL 0)

View File

@@ -1,29 +1,15 @@
-list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
-add_custom_target(example_grouped_conv_bwd_weight)
-add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
-add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
+add_custom_target(example_grouped_conv_bwd_weight)
+add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
-add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
-add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
+add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
+add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
-add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
-add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
-set(target 1)
-endif()
+add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
+add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
-if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
-add_custom_target(example_grouped_conv_bwd_weight)
-add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
-add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
-set(target 1)
-endif()
-endforeach()
add_custom_target(example_grouped_conv_bwd_weight_dl)
+add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
+add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)

View File

@@ -1,12 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
-add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
-add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
-add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
+add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
+add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
+add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)

View File

@@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear)
# FP32
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
# FP64
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
# FP16
add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
# BF16
add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp)
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
+add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp)
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
+add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
-add_dependencies(example_contraction example_contraction_scale)
-add_dependencies(example_contraction example_contraction_bilinear)
+add_example_dependencies(example_contraction example_contraction_scale)
+add_example_dependencies(example_contraction example_contraction_bilinear)

View File

@@ -1,5 +1,2 @@
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
-if(GPU_TARGETS MATCHES "gfx11")
-add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
-endif()
+add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)

View File

@@ -1,40 +1,23 @@
-list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
-list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
+add_custom_target(example_grouped_conv_fwd_multiple_d)
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
-add_custom_target(example_grouped_conv_fwd_multiple_d)
-add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
-add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
-add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
-add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
+add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
-add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
-add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
-if(USE_BITINT_EXTENSION_INT4)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
-add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
-endif() # USE_BITINT_EXTENSION_INT4
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
+if(USE_BITINT_EXTENSION_INT4)
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
+endif() # USE_BITINT_EXTENSION_INT4
-set(target 1)
-endif()
-endforeach()
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
-add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)

View File

@@ -1,17 +1,9 @@
-list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
-add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
-add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
-if(USE_BITINT_EXTENSION_INT4)
-add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
-endif(USE_BITINT_EXTENSION_INT4)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
+add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
+add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
+if(USE_BITINT_EXTENSION_INT4)
+add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
+endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)

View File

@@ -1,11 +1,9 @@
-if(GPU_TARGETS MATCHES "gfx11")
-add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
-add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
-add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
-add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
-add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
-endif()
+add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
+add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
+add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
+add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
+add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)

View File

@@ -1,32 +1,23 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list AND target EQUAL 0)
-add_custom_target(example_splitK_gemm_xdl)
+add_custom_target(example_splitK_gemm_xdl)
-add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
+add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
-add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
+add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
-add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
+add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
-add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
+add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
-add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
+add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
-add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
+add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
-if(USE_BITINT_EXTENSION_INT4)
-add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
-add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
-endif()
-set(target 1)
-endif()
-endforeach()
+if(USE_BITINT_EXTENSION_INT4)
+add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
+endif()

View File

@@ -1,27 +1,10 @@
-list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
-add_custom_target(example_grouped_conv_bwd_data)
+add_custom_target(example_grouped_conv_bwd_data)
-add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
-add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
+add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
-add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
-add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
+add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
-set(target 1)
-endif()
-endforeach()
-foreach(gpu IN LISTS GPU_TARGETS)
-if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
-add_custom_target(example_grouped_conv_bwd_data)
-add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
-add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
-set(target 1)
-endif()
-endforeach()
+add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
+add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)

View File

@@ -1,24 +1,17 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
# Conv + bias + relu perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
# Conv + bias + relu perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
# Conv + bias + tanh perlayer quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
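# NOTE (editorial): the _dl examples above are built only when DL_KERNELS is defined; see
# the source filtering added to add_example_executable in example/CMakeLists.txt below.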

View File

@@ -1,17 +1,9 @@
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)

View File

@@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu
add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp)
add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp)
add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp)
add_example_executable(example_elementwise_permute elementwise_permute.cpp)
if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)

View File

@@ -0,0 +1,140 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include <stdexcept>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using BinaryAdd = ck::tensor_operation::element_wise::Add;
// B = alpha * A0 * A0 + beta * A1 * A1
using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise::
BinaryWithUnaryCombinedOp<BinaryAdd, UnaryScaleSquare, UnaryScaleSquare>;
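// Editorial sketch (assumed semantics, not from this diff): UnaryCombinedOp applies its ops
// left to right, so UnaryScaleSquare(x) = UnaryScale(UnarySquare(x)) = scale * (x * x), and
// BinaryWithUnaryCombinedOp feeds both unary results into BinaryAdd, giving per element:
//   b = alpha * a0 * a0 + beta * a1 * a1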
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType, ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
BinaryAddUnaryScaleSquare, // ElementwiseOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8, 8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides)};
Tensor<ADataType>& a0 = as[0];
Tensor<ADataType>& a1 = as[1];
Tensor<BDataType> b(ab_lengths, ab_strides);
float alpha = 3.f;
float beta = 2.f;
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0.mData.data());
a1_device_buf.ToDevice(a1.mData.data());
std::array<const void*, 2> inputs = {a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths,
{ab_strides, ab_strides},
{ab_strides},
inputs,
output,
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters are not supported by the device instance, exiting!");
};
std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
std::cout << "B (nchw): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
// two A inputs are read and one B output is written per element
std::size_t num_btype = 2 * sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
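// ave_time is in ms, so flop / 1e9 / ms == TFLOP/s and bytes / 1e6 / ms == GB/s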
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
Tensor<BDataType> host_b(ab_lengths, ab_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as,
host_b,
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
{
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_ndhwc(n, d, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -51,32 +39,7 @@ int main()
std::vector<std::size_t> ncdhw = {16, 8, 8, 8, 8};
std::vector<std::size_t> ndhwc = {16, 8, 8, 8, 8};
Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(ndhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths;
/**std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[4]),
1};
std::array<ck::index_t, 5> b_strides = {
static_cast<int>(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]),
static_cast<int>(ndhwc[2] * ndhwc[3] * ndhwc[4]),
1,
static_cast<int>(ndhwc[3] * ndhwc[4]),
static_cast<int>(ndhwc[4])};**/
std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
@@ -93,6 +56,20 @@ int main()
1};
ck::ranges::copy(ncdhw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -126,10 +103,16 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ndhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<4>, // InScalarPerVectorSeq
ck::Sequence<4>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
{
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_ndhwc(n, d, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -59,10 +47,13 @@ int main()
const int W = 5;
const int D = 16;
std::vector<std::size_t> ncdhw = {N, C, D, H, W};
std::vector<std::size_t> ndhwc = {N, D, H, W, C};
Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(ndhwc);
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -74,10 +65,6 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -94,11 +81,12 @@ int main()
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] *
ab_lengths[3] * ab_lengths[4];
std::size_t num_btype =
sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
(sizeof(ADataType) + sizeof(BDataType)) *
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -111,10 +99,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ndhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -55,18 +44,6 @@ int main()
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -77,9 +54,22 @@ int main()
1,
static_cast<int>(nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[3])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -111,10 +101,16 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
@@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance =
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
const std::vector<std::size_t>& shape_nchw,
Functor functor)
{
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -54,13 +40,16 @@ int main()
const int N = 120;
const int C = 128;
const int H = 32;
const int W = 1024;
const int W = 32;
std::vector<std::size_t> nchw = {N, C, H, W};
std::vector<std::size_t> nhwc = {N, H, W, C};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -72,11 +61,6 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
@@ -94,10 +78,11 @@ int main()
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t flop =
std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) *
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -110,11 +95,16 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{});
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -6,9 +6,11 @@
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -21,11 +23,14 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}
int main()
{
bool do_verification = true;
@@ -60,8 +48,21 @@ int main()
std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 1.f;
auto i = 0;
std::mt19937 gen(11939);
@@ -84,22 +85,14 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -113,11 +106,10 @@ int main()
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype =
(2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
@@ -129,10 +121,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -5,9 +5,11 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -55,18 +47,6 @@ int main()
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -80,9 +60,29 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -112,10 +112,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -5,9 +5,11 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {5, 4, 2, 3};
std::vector<std::size_t> nhwc = {5, 2, 3, 4};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 1.f;
auto i = 0;
@@ -84,22 +86,14 @@ int main()
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -129,10 +123,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -5,9 +5,11 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -20,11 +22,14 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
UnaryScaleSquare, // UnaryScaleSquare
4, // NumDim
256, // BlockSize
128, // M0PerBlock
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
@@ -55,18 +47,6 @@ int main()
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
@@ -80,9 +60,28 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(ab_lengths, b_strides);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
auto argument =
broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
@@ -112,10 +111,17 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, UnaryOp{scale});
Tensor<BDataType> host_b(ab_lengths, b_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
ref_invoker.Run(ref_argument);
b_device_buf.FromDevice(b.mData.data());
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}

View File

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using UnaryScale = ck::tensor_operation::element_wise::Scale;
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
using UnaryScaleSquare =
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
using BinaryAdd = ck::tensor_operation::element_wise::Add;
// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2
using TrinaryAddUnaryScaleSquare =
ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp<BinaryAdd,
BinaryAdd,
UnaryScaleSquare,
UnaryScaleSquare,
UnaryScaleSquare>;
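// Editorial sketch (assumed semantics, matching the comment above): each UnaryScaleSquare
// yields scale * (x * x); the two BinaryAdd stages then fold the three results, per element:
//   b = alpha * a0 * a0 + beta * a1 * a1 + gamma * a2 * a2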
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType, ADataType, ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
TrinaryAddUnaryScaleSquare, // ElementwiseOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
int main()
{
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
1};
ck::ranges::copy(nchw, ab_lengths.begin());
std::array<Tensor<ADataType>, 3> as = {Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides),
Tensor<ADataType>(ab_lengths, ab_strides)};
Tensor<ADataType>& a0 = as[0];
Tensor<ADataType>& a1 = as[1];
Tensor<ADataType>& a2 = as[2];
Tensor<BDataType> b(ab_lengths, ab_strides);
float alpha = 3.f;
float beta = 2.f;
float gamma = 4.f;
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a2.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0.mData.data());
a1_device_buf.ToDevice(a1.mData.data());
a2_device_buf.ToDevice(a2.mData.data());
std::array<const void*, 3> inputs = {a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
a2_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths,
{ab_strides, ab_strides, ab_strides},
{ab_strides},
inputs,
output,
TrinaryAddUnaryScaleSquare{
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters are not supported by the device instance, exiting!");
};
std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
std::cout << "A2 (nchw): " << a2.mDesc << std::endl;
std::cout << "B (nchw): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
// full count per output element: 3 squares + 3 scales + 2 adds = 8 flops
std::size_t flop = std::size_t(8) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
// three A inputs are read and one B output is written per element
std::size_t num_btype = 3 * sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
Tensor<BDataType> host_b(ab_lengths, ab_strides);
using ReferenceElementwiseInstance = ck::tensor_operation::host::
ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(
as,
host_b,
TrinaryAddUnaryScaleSquare{
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
ref_invoker.Run(ref_argument);
const double threshold = std::pow(2, -10) * 2;
b_device_buf.FromDevice(b.mData.data());
pass &= ck::utils::check_err(
b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold);
}
return pass ? 0 : 1;
}

View File

@@ -1,8 +1 @@
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp)

View File

@@ -1,15 +1,7 @@
add_custom_target(example_im2col_col2im)
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)

View File

@@ -1,8 +1 @@
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)

View File

@@ -1,8 +1 @@
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)

View File

@@ -2,16 +2,9 @@ add_subdirectory(binary)
add_subdirectory(multi_AB)
add_subdirectory(unary)
add_custom_target(example_convnd_activ_xdl)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)

View File

@@ -1,5 +1,3 @@
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)

View File

@@ -5,6 +5,12 @@ include_directories(BEFORE
add_custom_target(examples)
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(FILE_NAME)
add_dependencies(${EXAMPLE_NAME} ${FILE_NAME})
endif()
endfunction()
function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
@@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
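# Illustrative example (editorial; file names hypothetical): with GPU_TARGETS=gfx1100 and
# DL_KERNELS unset, a call such as
#   add_example_executable(example_mixed foo_dl.cpp bar_xdl.cpp baz_wmma.cpp)
# drops foo_dl.cpp (DL_KERNELS unset) and bar_xdl.cpp (no gfx9 target), building only
# baz_wmma.cpp.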
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})

View File

@@ -56,12 +56,13 @@ auto create_args(int argc, char* argv[])
.insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left local-attn with left right size\n"
"'b:l,r', bottom-r local-attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size\n")
.insert(
"mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left sliding window attn with left right size\n"
"'b:l,r', bottom-r sliding window attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
@@ -357,44 +358,45 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
return fmha_fwd_args<FmhaDefaultElementFunctions>{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
ck_tile::identity{},
ck_tile::identity{}};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
@@ -508,12 +510,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
}
else
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
// a negative left window size means causal
// otherwise generic (for the current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
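The reference masking above treats (y, x) as a rectangular visibility window per query row. Below is a minimal host-side sketch of that semantics; it is inferred from the causal cases in mask_info (y = seqlen_q, x = 1 for top-left causal), not taken from the actual ck_tile implementation, and the edge handling is an assumption.

#include <algorithm>
#include <cstdio>

// visible iff i_n lies in [max(i_m - y + 1, 0), min(i_m + x, x_total))
bool is_visible(int i_m, int i_n, int y, int x, int x_total)
{
    const int lo = std::max(i_m - y + 1, 0); // first visible column
    const int hi = std::min(i_m + x, x_total); // first masked column
    return lo <= i_n && i_n < hi;
}

int main()
{
    const int seqlen_q = 4, seqlen_k = 4;
    // top-left causal: y = seqlen_q, x = 1 gives a lower-triangular pattern
    for(int i = 0; i < seqlen_q; ++i)
    {
        for(int j = 0; j < seqlen_k; ++j)
            std::printf("%c", is_visible(i, j, seqlen_q, 1, seqlen_k) ? 'o' : '.');
        std::printf("\n");
    }
    return 0;
}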
if(lse)
{

View File

@@ -104,169 +104,6 @@ struct FmhaMasks
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
#if 0
// internal API, don't use this directly
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t nhead_k,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t max_seqlen_q,
float scale,
bool i_perm,
bool o_perm,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x)
{
constexpr bool is_v_rowmajor =
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias'
/// are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_k : nhead_k * seqlen_k;
}();
const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_k : seqlen_k;
}();
const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k);
const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1);
const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
mask_y,
mask_x);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask_y,
mask_x);
}
}();
dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// These are the args from the caller to the underlying API, different from the kernel args
struct fmha_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t batch;
ck_tile::index_t nhead;
ck_tile::index_t nhead_k;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t max_seqlen_q;
float scale;
bool i_perm;
bool o_perm;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
};
#endif
// runtime args, some will be passed to kargs, some will be used to compute grids/blocks
template <typename ElementFunctions>
struct fmha_fwd_args
@@ -306,8 +143,9 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// typename ElementFunctions::QElementFunction q_element_func;
// typename ElementFunctions::KElementFunction k_element_func;
// typename ElementFunctions::VElementFunction v_element_func;
@@ -350,8 +188,9 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
@@ -392,8 +231,9 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.batch_stride_bias,
args.batch_stride_lse,
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
@@ -420,6 +260,7 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
bool kHasBias_,
bool kStoreLse_,
@@ -439,6 +280,7 @@ struct fmha_fwd_traits_
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;

View File

@@ -24,6 +24,16 @@ DTYPE_BITS = {
"bf8" : 8
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
@@ -49,13 +59,16 @@ ELEMENT_FUNC_MAP = {
"no" : "FmhaDefaultElementFunctions",
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
MASKS = ["no", "causal", "generic"]
DIRECTIONS = ["fwd"]
GEN_DIR = "" # in Cmake, have to generate files in same folder
@@ -130,7 +143,8 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
@@ -168,17 +182,40 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
"""
MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
def get_mask_map(mask : str):
if mask == "generic":
return MASK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_check_map(mask : str):
if mask == "generic":
return MASK_CHECK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
@@ -212,14 +249,19 @@ class FmhaFwdApiTrait:
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
@@ -228,7 +270,7 @@ class FmhaFwdApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@@ -239,7 +281,7 @@ class FmhaFwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
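The repeated "order of get_pipelines() matters" TODOs exist because the padded 'qr' variants above now emit a literal true for their shape checks, so in the generated if/else-if dispatch they match every problem size and must be generated after the exact-match (unpadded) branches. A minimal sketch of the generated pattern, with hypothetical trait names and a made-up tile size of 128:

#include <cstdio>

struct args_t { int seqlen_q, seqlen_k; };

// The padded branch's check collapses to `true`, making it a catch-all;
// if it were emitted first, the unpadded fast path would never be taken.
const char* dispatch(const args_t& a)
{
    if(a.seqlen_q % 128 == 0 && a.seqlen_k % 128 == 0)
        return "trait_unpadded"; // exact-match fast path
    else if(true /*a.seqlen_q % 128 != 0*/)
        return "trait_padded"; // padded catch-all, generated last
    return "unreachable";
}

int main()
{
    std::printf("%s\n", dispatch({256, 256})); // trait_unpadded
    std::printf("%s\n", dispatch({100, 200})); // trait_padded
    return 0;
}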
@@ -270,13 +312,17 @@ class FmhaFwdPipeline:
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias == 't' : n += '_bias'
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
return n
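As a worked example of the naming above (ignoring the padding component pn, which is built outside the lines shown): tag 'qr' with row-major V layout, bias 't', mask 's_mask' and lse 't' produces a name starting with 'qr_vr' and ending in '_bias_mask_lse', while the same pipeline with the generic-mask name 'causal' appends '_mc' instead of '_mask'.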
class FmhaFwdApiPool:
def __init__(self):
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
@@ -298,8 +344,9 @@ class FmhaFwdApiPool:
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask],
F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
@@ -341,6 +388,7 @@ class FmhaFwdKernel:
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
F_element_func : str
@property
@@ -369,8 +417,9 @@ class FmhaFwdKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy ,
F_mask = MASK_MAP[self.F_pipeline.F_mask],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_element_func = ELEMENT_FUNC_MAP[self.F_element_func])
@@ -426,14 +475,17 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick one value for vlayout/pipeline/pad
# support this in the future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list of possible pipelines
# TODO: the order of this list matters! entries later in the list will also be checked later
# TODO: currently for the qr pipeline, let 't' padding appear later!!
# TODO: how to design this more generically?
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask))
@@ -446,16 +498,19 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask))
if receipt == 1:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) # TODO: cover arbitrary hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, mask)) # TODO: cover arbitrary hdim
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool()
api_pool = FmhaFwdApiPool(mask_impl)
for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
@@ -474,6 +529,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_element_func=element_func)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
@@ -489,24 +545,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None:
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
api_pool, kernels = get_blobs(kernel_filter)
api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_api(api_pool, output_dir)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None:
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
with file_path.open('a') as f:
_, kernels = get_blobs(kernel_filter)
_, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -535,8 +591,26 @@ if __name__ == "__main__":
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim"
)
args = parser.parse_args()
if args.list_blobs is not None:
list_blobs(args.list_blobs, args.filter)
list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
else:
write_blobs(args.output_dir, args.filter)
write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)

View File

@@ -9,11 +9,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
causal_top_left,
causal_bottom_right,
mask_top_left,
mask_bottom_right,
window_generic,
};
@@ -21,18 +22,19 @@ struct mask_info
{
mask_enum type;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::causal_top_left)
os << "tl";
else if(type == mask_enum::causal_bottom_right)
os << "br";
else if(type == mask_enum::mask_top_left)
os << "tl(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "br(" << left << ":" << right << ")";
else
{
os << "g(" << y << "/" << x << ")";
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
@@ -57,22 +59,30 @@ struct mask_info
// TODO: some validation
if(t == "t")
{
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
@@ -84,15 +94,19 @@ struct mask_info
{
// should be 0, 1, 2
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::causal_top_left)
if(tmp.type == mask_enum::mask_top_left)
{
tmp.y = seqlen_q;
tmp.x = 1;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(tmp.type == mask_enum::causal_bottom_right)
else if(tmp.type == mask_enum::mask_bottom_right)
{
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
}
return tmp;
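Worked examples of decode, taking seqlen_q = 200 and seqlen_k = 520 (all values follow from the branches above): "1" yields mask_top_left with y = 200, x = 1, left = -1, right = 0; "2" yields mask_bottom_right with y = 200, x = 520 - 200 + 1 = 321, left = -1, right = 0; "t:128,30" yields mask_top_left with left = 128, right = 30 and y/x computed by make_generic_attention_mask_coordinates_from_lr_window; "g:128,32" stores y = 128 and x = 32 directly.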

View File

@@ -23,7 +23,8 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done