Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00)
External Interface (#304)
* add client example
* clean
* clean
* reorg
* clean up profiler
* reorg
* clean
* fix profiler
* function for getinstances
* update client example
* update client example
* update client example
* update
* update example
* update Jenkins file
* update cmake
* update Jenkins
[ROCm/composable_kernel commit: aebd211c36]
Jenkinsfile (vendored)
@@ -379,23 +379,23 @@ pipeline {
                 }
             }
         }
-        //stage("Client App")
-        //{
-        //    parallel
-        //    {
-        //        stage("Run Client App")
-        //        {
-        //            agent{ label rocmnode("gfx908")}
-        //            environment{
-        //                setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
-        //                execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """
-        //            }
-        //            steps{
-        //                buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
-        //            }
-        //        }
-        //    }
-        //}
+        stage("Client App")
+        {
+            parallel
+            {
+                stage("Run Client App")
+                {
+                    agent{ label rocmnode("gfx908")}
+                    environment{
+                        setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
+                        execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc .. && make -j """
+                    }
+                    steps{
+                        buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
+                    }
+                }
+            }
+        }
         stage("Performance Tests")
         {
             parallel
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt (new file)
@@ -0,0 +1,2 @@
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations)
client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp (new file)
@@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iomanip>
#include <vector>
#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp"

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = AddAddFastGelu;

using ADataType   = F16;
using BDataType   = F16;
using AccDataType = F32;
using D0DataType  = F16;
using D1DataType  = F16;
using EDataType   = F16;

using ALayout  = Row;
using BLayout  = Col;
using D0Layout = Row;
using D1Layout = Row;
using ELayout  = Row;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main(int argc, char* argv[])
{
    // GEMM shape
    ck::index_t M = 3840;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    ck::index_t StrideA  = 4096;
    ck::index_t StrideB  = 4096;
    ck::index_t StrideD0 = 0;
    ck::index_t StrideD1 = 4096;
    ck::index_t StrideE  = 4096;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 9)
    {
        M = std::stoi(argv[1]);
        N = std::stoi(argv[2]);
        K = std::stoi(argv[3]);

        StrideA  = std::stoi(argv[4]);
        StrideB  = std::stoi(argv[5]);
        StrideD0 = std::stoi(argv[6]);
        StrideD1 = std::stoi(argv[7]);
        StrideE  = std::stoi(argv[8]);
    }
    else
    {
        printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
        exit(0);
    }

    auto f_matrix_space_size =
        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
            using Layout = decltype(layout);

            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return (nRow - 1) * stride + nCol;
            }
            else
            {
                return (nCol - 1) * stride + nRow;
            }
        };

    SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
    SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
    SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) *
                                      f_matrix_space_size(M, N, StrideD0, D0Layout{}));
    SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) *
                                      f_matrix_space_size(M, N, StrideD1, D1Layout{}));
    SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));

    // add device op instances
    const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
        get_device_gemm_add_add_fastgelu_instances<ADataType,
                                                   BDataType,
                                                   AccDataType,
                                                   D0DataType,
                                                   D1DataType,
                                                   EDataType,
                                                   ALayout,
                                                   BLayout,
                                                   D0Layout,
                                                   D1Layout,
                                                   ELayout>();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    const auto a_element_op   = AElementOp{};
    const auto b_element_op   = BElementOp{};
    const auto cde_element_op = CDEElementOp{};

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b_device_buf.GetDeviceBuffer(),
            std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
                                       d1_m_n_device_buf.GetDeviceBuffer()},
            e_device_buf.GetDeviceBuffer(),
            M,
            N,
            K,
            StrideA,
            StrideB,
            std::array<ck::index_t, 2>{StrideD0, StrideD1},
            StrideE,
            a_element_op,
            b_element_op,
            cde_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop = std::size_t(2) * M * N * K;

            std::size_t num_btype =
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                found           = true;
                best_op_id      = i;
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];

        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b_device_buf.GetDeviceBuffer(),
            std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
                                       d1_m_n_device_buf.GetDeviceBuffer()},
            e_device_buf.GetDeviceBuffer(),
            M,
            N,
            K,
            StrideA,
            StrideB,
            std::array<ck::index_t, 2>{StrideD0, StrideD1},
            StrideE,
            a_element_op,
            b_element_op,
            cde_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }
    }

    return 0;
}
client_example/CMakeLists.txt (new file)
@@ -0,0 +1,9 @@
cmake_minimum_required(VERSION 3.15)
project(ck_app)
add_compile_options(-std=c++17)

find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")

add_subdirectory(02_gemm_add_add_fastgelu)
client_example/README.md (new file)
@@ -0,0 +1,32 @@
## Introduction
The client applications link against the installed CK library, so the CK library must be installed before building them (a minimal install sketch follows).
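## Install CK (a sketch)
A minimal sketch of installing CK from its source tree, mirroring the CMake flags this commit's CI stage passes; the build directory, install prefix, and offload arch are assumptions to adjust for your system:

```bash
# from the composable_kernel source tree (flags follow the CI setup_args;
# the prefix and --offload-arch are assumptions -- adjust for your GPU)
mkdir -p build && cd build
cmake \
  -D BUILD_DEV=Off \
  -D CMAKE_INSTALL_PREFIX=../install \
  -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3" \
  ..
make -j install
```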
## Docker script
```bash
docker run \
  -it \
  --privileged \
  --group-add sudo \
  -w /root/workspace \
  -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
  rocm/tensorflow:rocm5.1-tf2.6-dev \
  /bin/bash
```

## Build
```bash
mkdir -p client_example/build
cd client_example/build
```

```bash
cmake \
  -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -D CMAKE_PREFIX_PATH=/opt/rocm \
  ..
```

### Build client example
```bash
make -j
```
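Once built, the example can be run directly; it prints the instances found, times each one, and reruns the best. A hypothetical invocation (the binary path follows the `add_subdirectory` layout above; arguments are optional and default to the 3840x4096x4096 problem):

```bash
# default problem size
./02_gemm_add_add_fastgelu/client_gemm_add_add_fastgelu

# explicit problem: M N K StrideA StrideB StrideD0 StrideD1 StrideE
./02_gemm_add_add_fastgelu/client_gemm_add_add_fastgelu 3840 4096 4096 4096 4096 0 4096 4096
```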
@@ -84,8 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
         8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
 // clang-format on
 
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<float, float, float, float, PassThrough, PassThrough, PassThrough>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AccDataType,
+                                                                        PassThrough,
+                                                                        PassThrough,
+                                                                        PassThrough>;
 
 int main(int argc, char* argv[])
 {
@@ -216,24 +221,17 @@ int main(int argc, char* argv[])
 
     if(do_verification)
     {
-        Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
-        Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-        Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-        Tensor<float> c_m_n_device_f32_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-        bf16_to_f32_(a_m_k, a_f32_m_k);
-        bf16_to_f32_(b_k_n, b_f32_k_n);
-        bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
+        Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
 
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
 
         auto ref_argument = ref_gemm.MakeArgument(
-            a_f32_m_k, b_f32_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
 
         ref_invoker.Run(ref_argument);
 
-        return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
+        return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
     }
 
     return 0;
ck/tensor_operation/gpu/device/device_batched_gemm.hpp (new file)
@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              ck::index_t M,
                                                              ck::index_t N,
                                                              ck::index_t K,
                                                              ck::index_t StrideA,
                                                              ck::index_t StrideB,
                                                              ck::index_t StrideC,
                                                              AElementwiseOperation a_element_op,
                                                              BElementwiseOperation b_element_op,
                                                              CElementwiseOperation c_element_op,
                                                              ck::index_t Batch) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<
    DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
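The `DeviceBatchedGemmPtr` defined by this interface is what the instance factory added later in this commit (`get_device_batched_gemm_instances`, see the new `device_batched_gemm_instance` header below) hands to clients. A minimal dispatch sketch, assuming f16 row/column-major operands, pre-allocated device buffers, and an include path inferred from this commit's layout:

```cpp
// Sketch only: enumerate the registered batched-GEMM instances and run the
// first one that supports the problem. The instance-header path and the
// f16 row/col/row type choice are assumptions for illustration.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

bool run_first_supported_batched_gemm(const void* p_a, const void* p_b, void* p_c,
                                      ck::index_t M, ck::index_t N, ck::index_t K,
                                      ck::index_t StrideA, ck::index_t StrideB,
                                      ck::index_t StrideC, ck::index_t Batch)
{
    // enumerate every instance registered for this type/layout combination
    auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance::
        get_device_batched_gemm_instances<ck::half_t,
                                          ck::half_t,
                                          ck::half_t,
                                          ck::tensor_layout::gemm::RowMajor,
                                          ck::tensor_layout::gemm::ColumnMajor,
                                          ck::tensor_layout::gemm::RowMajor>();

    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                                        StrideA, StrideB, StrideC,
                                                        PassThrough{}, PassThrough{},
                                                        PassThrough{}, Batch);

        // not every tuning instance supports every problem shape; probe first
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            op_ptr->MakeInvokerPointer()->Run(argument_ptr.get(), StreamConfig{nullptr, false});
            return true;
        }
    }
    return false;
}
```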
ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp (new file)
@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <iostream>
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename DxsInElementwiseOperation,
          typename DxsReduceAccElementwiseOperation>
struct DeviceBatchedGemmReduce : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        void* p_dxs,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        DxsInElementwiseOperation dxs_in_element_op,
                        DxsReduceAccElementwiseOperation dxs_out_element_op,
                        ck::index_t Batch) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename DxsInElementwiseOperation,
          typename DxsReduceAccElementwiseOperation>
using DeviceBatchedGemmReducePtr =
    std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
                                            BElementwiseOperation,
                                            CElementwiseOperation,
                                            DxsInElementwiseOperation,
                                            DxsReduceAccElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -111,7 +111,7 @@ __global__ void
     ignore = d_grid_desc_mblock_mperblock;
     ignore = compute_base_ptr_of_batch_;
     ignore = block_2_ctile_map;
-#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
+#endif
 }
 
 // Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Because non c-shuffle
@@ -169,11 +169,11 @@ template <typename ALayout,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceBatchedGemmReduce_Xdl_CShuffle
-    : public DeviceGemmReduce<AElementwiseOperation,
-                              BElementwiseOperation,
-                              CElementwiseOperation,
-                              DxsInElementwiseOperation,
-                              DxsReduceAccElementwiseOperation>
+    : public DeviceBatchedGemmReduce<AElementwiseOperation,
+                                     BElementwiseOperation,
+                                     CElementwiseOperation,
+                                     DxsInElementwiseOperation,
+                                     DxsReduceAccElementwiseOperation>
 {
     using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
 
@@ -594,12 +594,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                  CElementwiseOperation c_element_op,
                  DxsInElementwiseOperation dxs_in_element_op,
                  DxsReduceAccElementwiseOperation dxs_out_element_op,
-                 index_t BatchCount)
+                 index_t Batch)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
              p_ds_grid_{p_ds_grid},
-             BatchCount_(BatchCount),
+             Batch_(Batch),
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
@@ -637,7 +637,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
         DPtrsGlobal p_ds_grid_;
-        index_t BatchCount_;
+        index_t Batch_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
@@ -663,7 +663,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
         {
 #if 0
             {
-                std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl;
+                std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
 
                 std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                           << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
@@ -692,7 +692,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
             }
 
             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
 
             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -728,7 +728,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                     arg.p_b_grid_,
                     arg.p_c_grid_,
                     arg.p_ds_grid_,
-                    arg.BatchCount_,
+                    arg.Batch_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
@@ -771,7 +771,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                     arg.p_b_grid_,
                     arg.p_c_grid_,
                     arg.p_ds_grid_,
-                    arg.BatchCount_,
+                    arg.Batch_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
@@ -839,7 +839,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                                  CElementwiseOperation c_element_op,
                                  DxsInElementwiseOperation dxs_in_element_op,
                                  DxsReduceAccElementwiseOperation dxs_out_element_op,
-                                 index_t BatchCount)
+                                 index_t Batch)
     {
         return Argument{p_a,
                         p_b,
@@ -856,7 +856,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                         c_element_op,
                         dxs_in_element_op,
                         dxs_out_element_op,
-                        BatchCount};
+                        Batch};
     }
 
     static auto MakeInvoker() { return Invoker{}; }
@@ -878,7 +878,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                         CElementwiseOperation c_element_op,
                         DxsInElementwiseOperation dxs_in_element_op,
                         DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        index_t BatchCount) override
+                        index_t Batch) override
     {
         DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -896,7 +896,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                                           c_element_op,
                                           dxs_in_element_op,
                                           dxs_out_element_op,
-                                          BatchCount);
+                                          Batch);
     }
 
     // polymorphic
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -152,7 +152,7 @@ template <typename ADataType,
           ck::index_t CThreadTransferSrcDstVectorDim,
           ck::index_t CThreadTransferDstScalarPerVector>
 struct DeviceBatchedGemmXdl
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -339,11 +339,11 @@ struct DeviceBatchedGemmXdl
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 index_t BatchCount)
+                 index_t Batch)
             : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-             BatchCount_(BatchCount),
+             Batch_(Batch),
              a_grid_desc_k0_m_k1_{
                  DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
              b_grid_desc_k0_n_k1_{
@@ -376,7 +376,7 @@ struct DeviceBatchedGemmXdl
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        index_t BatchCount_;
+        index_t Batch_;
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
@@ -420,7 +420,7 @@ struct DeviceBatchedGemmXdl
             }
 
             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
 
             const auto K =
                 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -451,7 +451,7 @@ struct DeviceBatchedGemmXdl
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.BatchCount_,
+                    arg.Batch_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
@@ -485,7 +485,7 @@ struct DeviceBatchedGemmXdl
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.BatchCount_,
+                    arg.Batch_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
@@ -539,7 +539,7 @@ struct DeviceBatchedGemmXdl
                                  AElementwiseOperation a_element_op,
                                  BElementwiseOperation b_element_op,
                                  CElementwiseOperation c_element_op,
-                                 index_t BatchCount)
+                                 index_t Batch)
     {
         return Argument{p_a,
                         p_b,
@@ -555,7 +555,7 @@ struct DeviceBatchedGemmXdl
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        BatchCount};
+                        Batch};
     }
 
     static auto MakeInvoker() { return Invoker{}; }
@@ -573,7 +573,7 @@ struct DeviceBatchedGemmXdl
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op,
-                        index_t BatchCount) override
+                        index_t Batch) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
@@ -589,7 +589,7 @@ struct DeviceBatchedGemmXdl
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          BatchCount);
+                                          Batch);
     }
 
     // polymorphic
ck/tensor_operation/gpu/device/device_gemm_splitk.hpp (new file)
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmSplitK : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              ck::index_t M,
                                                              ck::index_t N,
                                                              ck::index_t K,
                                                              ck::index_t StrideA,
                                                              ck::index_t StrideB,
                                                              ck::index_t StrideC,
                                                              AElementwiseOperation a_element_op,
                                                              BElementwiseOperation b_element_op,
                                                              CElementwiseOperation c_element_op,
                                                              ck::index_t KBatch) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmSplitKPtr = std::unique_ptr<
    DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -57,7 +57,7 @@ template <typename ADataType,
           ck::index_t CThreadTransferSrcDstVectorDim,
           ck::index_t CThreadTransferDstScalarPerVector>
 struct DeviceGemmXdlSplitK
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -59,7 +59,7 @@ template <typename ADataType,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
 struct DeviceGemmXdlSplitKCShuffle
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -420,21 +420,22 @@ struct DeviceGemmXdlSplitKCShuffle
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                     sizeof(CDataType)));
 
-                launch_and_time_kernel(stream_config,
-                                       kernel,
-                                       dim3(grid_size),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.p_a_grid_,
-                                       arg.p_b_grid_,
-                                       arg.p_c_grid_,
-                                       arg.a_grid_desc_kbatch_k0_m_k1_,
-                                       arg.b_grid_desc_kbatch_k0_n_k1_,
-                                       arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                       arg.a_element_op_,
-                                       arg.b_element_op_,
-                                       arg.c_element_op_,
-                                       arg.block_2_ctile_map_);
+                ave_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.a_grid_desc_kbatch_k0_m_k1_,
+                                           arg.b_grid_desc_kbatch_k0_n_k1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.block_2_ctile_map_);
             };
 
             if(has_main_k0_block_loop)
@@ -1,3 +1,3 @@
-add_subdirectory(src/host_tensor)
 add_subdirectory(src/tensor_operation_instance/gpu)
+add_subdirectory(src/host_tensor)
 add_subdirectory(src/utility)
@@ -364,13 +364,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
 {
 }
 
 void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
 
-#if 1
-// FIXME: remove
-void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
-#endif
-
 template <typename T>
 float check_error(const Tensor<T>& ref, const Tensor<T>& result)
 {
@@ -416,3 +411,4 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
 
     return linf_error;
 }
+#endif
@@ -62,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
 
                 for(int k = 0; k < K; ++k)
                 {
-                    float v_a;
-                    float v_b;
+                    ADataType v_a;
+                    BDataType v_b;
 
-                    arg.a_element_op_(v_a, static_cast<const float>(arg.a_g_m_k_(g, m, k)));
-                    arg.b_element_op_(v_b, static_cast<const float>(arg.b_g_k_n_(g, k, n)));
+                    arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
+                    arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
 
-                    v_acc += v_a * v_b;
+                    v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
                 }
 
                 float v_c;
 
                 arg.c_element_op_(v_c, v_acc);
 
-                arg.c_g_m_n_(g, m, n) = v_c;
+                arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
             };
 
             make_ParallelTensorFunctor(f_gmk_gkn_gmn,
@@ -63,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
 
                 for(int k = 0; k < K; ++k)
                 {
-                    AccDataType v_a;
-                    AccDataType v_b;
+                    ADataType v_a;
+                    BDataType v_b;
 
-                    arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
-                    arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
+                    arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+                    arg.b_element_op_(v_b, arg.b_k_n_(k, n));
 
-                    v_acc += v_a * v_b;
+                    v_acc +=
+                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }
 
                 AccDataType v_c;
 
                 arg.c_element_op_(v_c, v_acc);
 
-                arg.c_m_n_(m, n) = v_c;
+                arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
             };
 
             make_ParallelTensorFunctor(
@@ -0,0 +1,203 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {

using DeviceBatchedGemmNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
    std::vector<DeviceBatchedGemmNoOpPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_batched_gemm_instances()
{
    std::vector<DeviceBatchedGemmNoOpPtr> op_ptrs;

    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, bhalf_t>::value && is_same<BDataType, bhalf_t>::value &&
                      is_same<CDataType, bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_batched_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp (new file)
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr<
    2,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::AddAddFastGelu>;

void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
    std::vector<DeviceGemmAddAddFastGeluPtr>&);

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename D0DataType,
          typename D1DataType,
          typename EDataType,
          typename ALayout,
          typename BLayout,
          typename D0Layout,
          typename D1Layout,
          typename ELayout>
auto get_device_gemm_add_add_fastgelu_instances()
{
    std::vector<DeviceGemmAddAddFastGeluPtr> op_ptrs;

    if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                 is_same_v<EDataType, half_t>)
    {
        if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                     is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
                     is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                          is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
                    op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
                    op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,286 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_gemm_instances()
{
    std::vector<DeviceGemmNoOpPtr> op_ptrs;

    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                      is_same<BDataType, ck::bhalf_t>::value &&
                      is_same<CDataType, ck::bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
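For reference, a minimal client-side sketch of consuming the instance list built above. This assumes the enclosing getter follows the get_device_*_instances naming used elsewhere in this commit, and that the header installs under the path below; both are assumptions, not part of the diff.

// Sketch only: getter name and include path are assumed, not taken from this diff.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"

int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;

    // Enumerate every f32 GEMM instance registered for row-major A, B and C.
    const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
        get_device_gemm_instances<float, float, float, Row, Row, Row>();

    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << std::endl;

    return 0;
}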
@@ -0,0 +1,124 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
auto get_device_gemm_splitk_instances()
{
    std::vector<DeviceGemmSplitKNoOpPtr> op_ptrs;

    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
        }
    }

    return op_ptrs;
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
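A similar sketch for the split-K getter declared above; the getter name is taken from this header, while the include path is an assumption about the install layout. GetTypeString comes from the common device-op base class, as used by the profiler below.

// Sketch only: the include path is an assumption about the install layout.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp"

int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    // All f16 split-K instances for A[M,K] row-major, B[N,K] column-major, C[M,N] row-major.
    const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
        get_device_gemm_splitk_instances<ck::half_t, ck::half_t, ck::half_t, Row, Col, Row>();

    std::cout << "found " << op_ptrs.size() << " split-K instances" << std::endl;
    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << std::endl;

    return 0;
}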
@@ -54,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)

    return os;
}

void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
{
    os << "dim " << desc.GetNumOfDimension() << ", ";

    os << "lengths {";
    LogRange(os, desc.GetLengths(), ", ");
    os << "}, ";

    os << "strides {";
    LogRange(os, desc.GetStrides(), ", ");
    os << "}" << std::endl;
}

#if 1
// FIXME: remove
void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst)
{
    for(std::size_t i = 0; i < src.mData.size(); ++i)
        dst.mData[i] = ck::type_convert<float>(src.mData[i]);
}
#endif
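The stream operator retained by this hunk is what the profiler uses to print tensor shapes (see the a_g_m_k/b_g_k_n logging further down). A tiny hedged usage sketch, with illustrative lengths and strides:

// Sketch only: constructor arguments are illustrative.
HostTensorDescriptor desc(std::vector<std::size_t>{8, 16},  // lengths
                          std::vector<std::size_t>{16, 1}); // strides
std::cout << "desc: " << desc << std::endl; // prints dim count, lengths, strides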
@@ -6,43 +6,45 @@ function(add_instance_library INSTANCE_NAME)
endfunction(add_instance_library INSTANCE_NAME)

add_subdirectory(gemm)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
add_subdirectory(conv3d_fwd)
add_subdirectory(conv2d_fwd_bias_relu)
add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data)
add_subdirectory(reduce)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_gemm)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(reduce)

add_library(device_operations STATIC
    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
    $<TARGET_OBJECTS:device_batched_gemm_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
    $<TARGET_OBJECTS:device_gemm_instance>
    $<TARGET_OBJECTS:device_gemm_splitk_instance>
    $<TARGET_OBJECTS:device_gemm_bias_relu_instance>
    $<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
    $<TARGET_OBJECTS:device_gemm_bias2d_instance>
    $<TARGET_OBJECTS:device_reduce_instance>
    $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
    $<TARGET_OBJECTS:device_grouped_gemm_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
    $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
    $<TARGET_OBJECTS:device_conv3d_fwd_instance>
    $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
    $<TARGET_OBJECTS:device_batched_gemm_instance>
    $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
    $<TARGET_OBJECTS:device_grouped_gemm_instance>
    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
    $<TARGET_OBJECTS:device_conv3d_fwd_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
    $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
    $<TARGET_OBJECTS:device_reduce_instance>
)
add_library(composablekernels::device_operations ALIAS device_operations)

@@ -67,8 +69,8 @@ target_include_directories(device_operations PUBLIC
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/thread>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/element>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
)
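For downstream users, a minimal consumer CMakeLists sketch; it assumes the install tree exports a composable_kernel CMake package, matching the composable_kernel::device_operations target linked by the client example in this commit.

# Sketch only: the package name and config export are assumed from the
# client_example usage; adjust CMAKE_PREFIX_PATH to your install tree.
cmake_minimum_required(VERSION 3.15)
project(ck_client CXX)

find_package(composable_kernel CONFIG REQUIRED)

add_executable(my_gemm my_gemm.cpp)
target_link_libraries(my_gemm PRIVATE composable_kernel::device_operations)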
@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{});

@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{});

@@ -48,7 +48,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{});

@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{});

@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});

@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{});

@@ -53,7 +53,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});

@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});

@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{});

@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{});

@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{});

@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{});

@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{});

@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{});

@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{});

@@ -51,7 +51,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
    >;

void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{});
@@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
    >;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
    std::vector<DeviceBatchedGemmReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,

@@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
    >;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
    std::vector<DeviceBatchedGemmReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,

@@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
    >;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
    std::vector<DeviceBatchedGemmReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,

@@ -64,9 +64,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
    >;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
    std::vector<DeviceBatchedGemmReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,
@@ -28,14 +28,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
    device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
    device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
    device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
    device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp;
    device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp;
    device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp;
@@ -0,0 +1,15 @@
set(DEVICE_GEMM_SPLITK_INSTANCE_SOURCE
    device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
    device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
)

add_library(device_gemm_splitk_instance OBJECT ${DEVICE_GEMM_SPLITK_INSTANCE_SOURCE})

target_compile_features(device_gemm_splitk_instance PUBLIC)
set_target_properties(device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
@@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances{});
@@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{});
@@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{});
@@ -83,7 +83,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    // >;

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{});
@@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{});
@@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{});
@@ -51,7 +51,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{});
@@ -51,7 +51,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple<
    >;

void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
    std::vector<DeviceGemmSplitKPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{});
@@ -6,6 +6,7 @@ include_directories(BEFORE
set(PROFILER_SOURCE
    src/profiler.cpp
    src/profile_gemm.cpp
    src/profile_gemm_splitk.cpp
    src/profile_gemm_bias_2d.cpp
    src/profile_gemm_bias_relu.cpp
    src/profile_gemm_bias_relu_add.cpp
@@ -27,21 +28,22 @@ add_executable(ckProfiler ${PROFILER_SOURCE})

target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE conv_util)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
@@ -7,56 +7,17 @@

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {

using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmNoOpPtr>&);

} // namespace device_batched_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {
@@ -103,27 +64,22 @@ bool profile_batched_gemm_impl(int do_verification,
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_g_m_n_device_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
    std::unique_ptr<Tensor<float>> c_f32_g_m_n_host_result   = nullptr;
    std::unique_ptr<Tensor<float>> c_f32_g_m_n_device_result = nullptr;

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;

    std::size_t num_thread = 1;
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }
    // set zero to c_device_buf
    c_g_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -135,56 +91,21 @@ bool profile_batched_gemm_impl(int do_verification,

    if(do_verification)
    {
        if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                     is_same<BDataType, ck::bhalf_t>::value &&
                     is_same<CDataType, ck::bhalf_t>::value)
        {
            Tensor<float> a_f32_g_m_k(
                f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
            Tensor<float> b_f32_g_k_n(
                f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
            c_f32_g_m_n_host_result = std::make_unique<Tensor<float>>(
                f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
            c_f32_g_m_n_device_result = std::make_unique<Tensor<float>>(
                f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
            using ReferenceBatchedGemmInstance =
                ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                 BDataType,
                                                                 CDataType,
                                                                 AElementOp,
                                                                 BElementOp,
                                                                 CElementOp>;

            bf16_to_f32_(a_g_m_k, a_f32_g_m_k);
            bf16_to_f32_(b_g_k_n, b_f32_g_k_n);
            auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
            auto ref_invoker      = ref_batched_gemm.MakeInvoker();

        using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
            ReferenceBatchedGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);

            auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
            auto ref_invoker      = ref_batched_gemm.MakeInvoker();

            auto ref_argument = ref_batched_gemm.MakeArgument(a_f32_g_m_k,
                                                              b_f32_g_k_n,
                                                              *c_f32_g_m_n_host_result,
                                                              a_element_op,
                                                              b_element_op,
                                                              c_element_op);

            ref_invoker.Run(ref_argument);
        }
        else
        {

            using ReferenceBatchedGemmInstance =
                ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                 BDataType,
                                                                 CDataType,
                                                                 AElementOp,
                                                                 BElementOp,
                                                                 CElementOp>;

            auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
            auto ref_invoker      = ref_batched_gemm.MakeInvoker();

            auto ref_argument = ref_batched_gemm.MakeArgument(
                a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);

            ref_invoker.Run(ref_argument);
        }
        ref_invoker.Run(ref_argument);
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
@@ -195,172 +116,51 @@ bool profile_batched_gemm_impl(int do_verification,
    b_device_buf.ToDevice(b_g_k_n.mData.data());
    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

    // add device GEMM instances
    std::vector<ck::tensor_operation::device::device_batched_gemm_instance::DeviceGemmNoOpPtr>
        gemm_ptrs;
    // add device op instances
    const auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance::
        get_device_batched_gemm_instances<ADataType,
                                          BDataType,
                                          CDataType,
                                          ALayout,
                                          BLayout,
                                          CLayout>();

    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                 is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(gemm_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, bhalf_t>::value && is_same<BDataType, bhalf_t>::value &&
                      is_same<CDataType, bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(gemm_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                      is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(gemm_ptrs);
        }
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_batched_gemm_instance::
                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(gemm_ptrs);
        }
    }

    if(gemm_ptrs.size() <= 0)
    if(op_ptrs.size() <= 0)
    {
        throw std::runtime_error("wrong! no device GEMM instance found");
    }

    std::string best_gemm_name;
    std::string best_op_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
    for(auto& gemm_ptr : gemm_ptrs)
    // profile device op instances
    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr =
            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                          static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          BatchCount);
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        BatchCount);

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            std::string gemm_name = gemm_ptr->GetTypeString();
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();

            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -376,11 +176,11 @@ bool profile_batched_gemm_impl(int do_verification,
            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << gemm_name << std::endl;
                      << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_gemm_name  = gemm_name;
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
@@ -390,20 +190,8 @@ bool profile_batched_gemm_impl(int do_verification,
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

                if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                             is_same<BDataType, ck::bhalf_t>::value &&
                             is_same<CDataType, ck::bhalf_t>::value)
                {

                    bf16_to_f32_(c_g_m_n_device_result, *c_f32_g_m_n_device_result);
                    float err = check_error(*c_f32_g_m_n_host_result, *c_f32_g_m_n_device_result);
                    pass      = pass && (err < 1E-6);
                }
                else
                {
                    float err = check_error(c_g_m_n_host_result, c_g_m_n_device_result);
                    pass      = pass && (err < 1E-6);
                }
                pass = pass &
                       ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);

                if(do_log)
                {
@@ -419,13 +207,12 @@ bool profile_batched_gemm_impl(int do_verification,
        }
        else
        {
            std::cout << "this device GEMM instance does not support this GEMM problem"
                      << std::endl;
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    return pass;
}
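A hypothetical driver call for the refactored impl; the parameter order mirrors the profile_gemm_add_add_fastgelu_impl signature shown later in this diff and may differ from the actual one.

// Sketch only: the exact parameter list of profile_batched_gemm_impl is not
// fully visible in this hunk; values below are placeholders.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

bool pass = ck::profiler::profile_batched_gemm_impl<ck::half_t, ck::half_t, ck::half_t,
                                                    Row, Col, Row>(
    /*do_verification=*/1, /*init_method=*/1, /*do_log=*/false, /*time_kernel=*/true,
    /*M=*/256, /*N=*/256, /*K=*/128,
    /*StrideA=*/128, /*StrideB=*/128, /*StrideC=*/256, /*BatchCount=*/4);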
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
@@ -29,7 +29,7 @@ using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps  = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;

using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
using DeviceBatchedGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmReducePtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
@@ -37,16 +37,16 @@ using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePt
    DOutElementOps>;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);
    std::vector<DeviceBatchedGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);
    std::vector<DeviceBatchedGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);
    std::vector<DeviceBatchedGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);
    std::vector<DeviceBatchedGemmReduceNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
@@ -204,7 +204,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
    b_device_buf.ToDevice(b_g_k_n.mData.data());

    // add device GEMM instances
    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceBatchedGemmReduceNoOpPtr>
        gemm_ptrs;

    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
@@ -1,12 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace profiler {

int profile_convnd_fwd(int argc, char* argv[]);

} // namespace profiler
} // namespace ck
@@ -9,6 +9,9 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
@@ -16,31 +19,6 @@
#include "ck/library/host_tensor/host_conv.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr<
2,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;

void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

@@ -55,18 +33,18 @@ template <typename ADataType,
typename D0Layout,
typename D1Layout,
typename ELayout>
int profile_gemm_add_add_fastgelu_impl(int do_verification,
int init_method,
bool /*do_log*/,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE)
bool profile_gemm_add_add_fastgelu_impl(int do_verification,
int init_method,
bool /*do_log*/,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -122,48 +100,21 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};

// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAddAddFastGeluPtr>
device_op_ptrs;
// add device op instances
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_add_add_fastgelu_instances<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
ELayout>();

if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
device_op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
device_op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
device_op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
device_op_ptrs);
}
}

std::cout << "found " << device_op_ptrs.size() << " instances" << std::endl;
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

// run reference
if(do_verification)
@@ -207,7 +158,7 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());

std::string best_device_op_name;
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
@@ -215,14 +166,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
bool pass = true;

// profile device operation instances
for(auto& device_op_ptr : device_op_ptrs)
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = device_op_ptr->MakeArgumentPointer(
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
d1_m_n_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
@@ -234,11 +185,11 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
b_element_op,
cde_element_op);

auto invoker_ptr = device_op_ptr->MakeInvokerPointer();
auto invoker_ptr = op_ptr->MakeInvokerPointer();

std::string device_op_name = device_op_ptr->GetTypeString();
std::string op_name = op_ptr->GetTypeString();

if(device_op_ptr->IsSupportedArgument(argument_ptr.get()))
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init E to zero before profiling a kernel
e_device_buf.SetZero();
@@ -256,14 +207,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << device_op_name << std::endl;
<< gb_per_sec << " GB/s, " << op_name << std::endl;

if(tflops > best_tflops)
{
best_device_op_name = device_op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}

if(do_verification)
@@ -276,14 +227,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
}
else
{
std::cout << device_op_name << " does not support this problem" << std::endl;
std::cout << op_name << " does not support this problem" << std::endl;
}
}

std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_device_op_name << std::endl;
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

return pass ? 0 : 1;
return pass;
}

} // namespace profiler

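Editor's note: the hunks above replace the hand-written per-layout add_device_gemm_add_add_fastgelu_xdl_c_shuffle_* dispatch with a single get_device_gemm_add_add_fastgelu_instances factory, and change the profiler entry point to return bool. A minimal sketch of how a client could drive the refactored entry point; the header path and the template-parameter order (data types before layouts, as in the factory call shown above) are assumptions, and main() scaffolding is not part of this commit:

    // Sketch only: exercises the bool-returning profiler entry added in this commit.
    #include "profiler/include/profile_gemm_add_add_fastgelu_impl.hpp" // assumed path

    int main()
    {
        using F16 = ck::half_t;
        using F32 = float;
        using Row = ck::tensor_layout::gemm::RowMajor;

        // Row-major 1024^3 problem with packed strides; values are illustrative.
        const bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<
            F16, F16, F32, F16, F16, F16, // A, B, Acc, D0, D1, E data types (assumed order)
            Row, Row, Row, Row, Row>(     // A, B, D0, D1, E layouts (assumed order)
            1 /*do_verification*/, 1 /*init_method*/, false /*do_log*/, true /*time_kernel*/,
            1024, 1024, 1024,
            1024 /*StrideA*/, 1024 /*StrideB*/, 1024 /*StrideD0*/, 1024 /*StrideD1*/, 1024 /*StrideE*/);

        return pass ? 0 : 1;
    }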
@@ -12,112 +12,37 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void profile_gemm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
int profile_gemm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC)
{
bool pass = true;

auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
@@ -134,32 +59,25 @@ void profile_gemm_impl(int do_verification,

Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;

std::size_t num_thread = 1;
switch(init_method)
{
// case 0: break;
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{}, num_thread);
break;
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}

// set zero to c_device_buf
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);

using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -176,303 +94,65 @@ void profile_gemm_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());

// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
// add device op instances
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>();

if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
}
else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
is_same<BDataType, ck::bhalf_t>::value &&
is_same<CDataType, ck::bhalf_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
is_same<CDataType, int8_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs);

ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs);
}
}

if(gemm_ptrs.size() <= 0)
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}

std::string best_gemm_name;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;

auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();

auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

ref_invoker.Run(ref_argument);
}

std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;

// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});

auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
auto invoker_ptr = op_ptr->MakeInvokerPointer();

if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
c_device_buf.SetZero();

std::string gemm_name = gemm_ptr->GetTypeString();
std::string op_name = op_ptr->GetTypeString();

float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -487,11 +167,11 @@ void profile_gemm_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << gemm_name << std::endl;
<< gb_per_sec << " GB/s, " << op_name << std::endl;

if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
@@ -501,86 +181,15 @@ void profile_gemm_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());

if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
is_same<BDataType, ck::bhalf_t>::value &&
is_same<CDataType, ck::bhalf_t>::value)
{
Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<float> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_device_f32_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

bf16_to_f32_(a_m_k, a_f32_m_k);
bf16_to_f32_(b_k_n, b_f32_k_n);
bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);

using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float,
float,
float,
float,
AElementOp,
BElementOp,
CElementOp>;

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
b_f32_k_n,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);

ref_invoker.Run(ref_argument);

ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData);

if(do_log)
{
LogRangeAsType<float>(
std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
<< std::endl;
}
}
else
{
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);

if(do_log)
{
LogRangeAsType<float>(
std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
<< std::endl;
}
}
pass =
pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);

if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
@@ -588,8 +197,7 @@ void profile_gemm_impl(int do_verification,
}
else
{
std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem"
<< std::endl;
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}

@@ -631,7 +239,9 @@ void profile_gemm_impl(int do_verification,
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_gemm_name << std::endl;
<< best_op_name << std::endl;

return pass ? 0 : 1;
}

} // namespace profiler

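Editor's note: the change above removes the per-type add_device_gemm_* dispatch from profile_gemm_impl and pulls instances from a single factory instead. A hedged sketch of consuming that factory directly (the factory call and include path are copied from the diff; the enclosing function is illustrative only):

    // Sketch: enumerate the GEMM instances the profiler will now iterate over.
    #include <iostream>
    #include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"

    void list_gemm_instances()
    {
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // fp16 in / fp16 out, row-major A, column-major B, row-major C.
        const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
            get_device_gemm_instances<ck::half_t, ck::half_t, ck::half_t, Row, Col, Row>();

        std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
        for(const auto& op_ptr : op_ptrs)
            std::cout << op_ptr->GetTypeString() << std::endl;
    }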
256
profiler/include/profile_gemm_splitk_impl.hpp
Normal file
@@ -0,0 +1,256 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iomanip>
#include <iostream>
#include <typeinfo>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace profiler {

template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_gemm_splitk_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
{
bool pass = true;

auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};

Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;

switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}

using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;

const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};

DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());

a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());

// add device op instances
const auto op_ptrs =
ck::tensor_operation::device::device_gemm_instance::get_device_gemm_splitk_instances<
ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout>();

if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device operation instance found");
}

// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

ref_invoker.Run(ref_argument);
}

std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;

// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
KBatch);

auto invoker_ptr = op_ptr->MakeInvokerPointer();

if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();

std::string op_name = op_ptr->GetTypeString();

float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

std::size_t flop = std::size_t(2) * M * N * K;

std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;

if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}

if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());

pass =
pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);

if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}

if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}

if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout =  RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout =  ColumnMajor";
}

if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout =  RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout =  ColumnMajor";
}

std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;

return pass;
}

} // namespace profiler
} // namespace ck
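Editor's note: the new split-K profiler header above is self-contained; KBatch splits the K dimension of the GEMM across multiple partial accumulations. A minimal caller might look like the following sketch (template-argument order and the runtime values are taken from the signature above; the problem sizes are arbitrary):

    // Sketch only: drive the new split-K profiler for a row-major fp16 GEMM.
    #include "profiler/include/profile_gemm_splitk_impl.hpp"

    int main()
    {
        using Row = ck::tensor_layout::gemm::RowMajor;

        // fp16 inputs, fp32 accumulation, K = 4096 split into 4 batches.
        const bool pass = ck::profiler::profile_gemm_splitk_impl<ck::half_t,
                                                                 ck::half_t,
                                                                 float,
                                                                 ck::half_t,
                                                                 Row,
                                                                 Row,
                                                                 Row>(
            1 /*do_verification*/, 1 /*init_method*/, false /*do_log*/, true /*time_kernel*/,
            1024, 1024, 4096, 4096 /*StrideA*/, 1024 /*StrideB*/, 1024 /*StrideC*/, 4 /*KBatch*/);

        return pass ? 0 : 1;
    }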
@@ -15,10 +15,6 @@ enum struct GemmMatrixLayout
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
MK_KN_NM, // 4
|
||||
MK_NK_NM, // 5
|
||||
KM_KN_NM, // 6
|
||||
KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
enum struct GemmDataType
|
||||
@@ -31,7 +27,7 @@ enum struct GemmDataType
|
||||
|
||||
int profile_batched_gemm(int argc, char* argv[])
|
||||
{
|
||||
if(!(argc == 15))
|
||||
if(argc != 15)
|
||||
{
|
||||
printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
|
||||
@@ -64,330 +60,117 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
|
||||
const int BatchCount = std::stoi(argv[14]);
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using INT8 = int8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
auto profile = [&](auto a_type,
|
||||
auto b_type,
|
||||
auto c_type,
|
||||
auto a_layout,
|
||||
auto b_layout,
|
||||
auto c_layout) {
|
||||
using ADataType = decltype(a_type);
|
||||
using BDataType = decltype(b_type);
|
||||
using CDataType = decltype(c_type);
|
||||
|
||||
using ALayout = decltype(a_layout);
|
||||
using BLayout = decltype(b_layout);
|
||||
using CLayout = decltype(c_layout);
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
bool pass = ck::profiler::
|
||||
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? DefaultStrideA : StrideA,
|
||||
(StrideB < 0) ? DefaultStrideB : StrideB,
|
||||
(StrideC < 0) ? DefaultStrideC : StrideC,
|
||||
BatchCount);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
    {
        return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_batched_gemm_impl<int8_t, int8_t, int8_t,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, BatchCount);
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_batched_gemm_impl<int8_t, int8_t, int8_t,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, BatchCount);
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_batched_gemm_impl<int8_t, int8_t, int8_t,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, BatchCount);
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_batched_gemm_impl<int8_t, int8_t, int8_t,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, BatchCount);
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
    }
    else
    {
        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
    }
        std::cout << "this data_type & layout is not implemented" << std::endl;

    return 0;
        return 1;
    }
}

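A note on the (StrideA < 0) ? K : StrideA ternaries above: the profiler treats a negative stride argument as "use the natural leading dimension for this layout", the column count for a row-major matrix and the row count for a column-major one. A minimal self-contained sketch of that convention; the helper name default_stride and the local layout tags are illustrative, not CK API:

    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    // Local stand-ins for ck::tensor_layout::gemm::{RowMajor, ColumnMajor} so
    // the sketch compiles on its own. For a row-major M x K matrix the leading
    // dimension defaults to K; for column-major it defaults to M, matching the
    // ternaries in the profiler sources.
    template <typename Layout>
    int default_stride(int rows, int cols, int user_stride)
    {
        if(user_stride >= 0)
            return user_stride; // an explicit stride always wins
        return std::is_same<Layout, RowMajor>::value ? cols : rows;
    }

    int main()
    {
        // Equivalent to (StrideA < 0) ? K : StrideA for row-major A[M x K].
        const int M = 3840, K = 4096;
        return default_stride<RowMajor>(M, K, -1) == K ? 0 : 1;
    }
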
@@ -10,11 +10,10 @@

#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/conv_util.hpp"
#include "ck/library/utility/fill.hpp"

#include "profiler/include/profile_convnd_fwd.hpp"

namespace {

enum struct ConvDataType
@@ -304,7 +303,7 @@ void profile_convnd_instances(ConvDataType data_type,

} // namespace

int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
int profile_convnd_fwd(int argc, char* argv[])
{
    using namespace ck::utils::conv;

@@ -14,10 +14,6 @@ enum struct GemmMatrixLayout
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
    MK_KN_NM, // 4
    MK_NK_NM, // 5
    KM_KN_NM, // 6
    KM_NK_NM, // 7
};

enum struct GemmDataType
@@ -30,7 +26,7 @@ enum struct GemmDataType

int profile_gemm(int argc, char* argv[])
{
    if(!(argc == 14 || argc == 15))
    if(argc != 14)
    {
        printf("arg1: tensor operation (gemm: GEMM)\n");
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
@@ -41,9 +37,8 @@ int profile_gemm(int argc, char* argv[])
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=n0, 1=yes)\n");
        printf("arg7: time kernel (0=no, 1=yes)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        printf("arg14: split k into mulitiple batch\n");
        exit(1);
    }

@@ -61,350 +56,125 @@ int profile_gemm(int argc, char* argv[])
    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
    int KBatch = 1;
    if(argc == 15)
        KBatch = std::stoi(argv[14]);

    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    using F32   = float;
    using F16   = ck::half_t;
    using BF16  = ck::bhalf_t;
    using INT8  = int8_t;
    using INT32 = int32_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto acc_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass = ck::profiler::profile_gemm_impl<ADataType, BDataType, AccDataType,
                                                    CDataType, ALayout, BLayout, CLayout>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? DefaultStrideA : StrideA,
            (StrideB < 0) ? DefaultStrideB : StrideB,
            (StrideC < 0) ? DefaultStrideC : StrideC);

        return pass ? 0 : 1;
    };

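The profile lambda above replaces the roughly 350-line if/else chain (see the hunk header @@ -61,350 +56,125 @@) with a value-as-type-tag idiom: each branch passes empty tag objects such as F16{} or Row{}, and the lambda recovers the compile-time types with decltype, so a single generic body instantiates profile_gemm_impl for every data-type/layout combination. A minimal standalone sketch of the same idiom; run_gemm and the local layout tags are stand-ins, not the CK entry points:

    #include <cstdio>
    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    // Stand-in for ck::profiler::profile_gemm_impl: reports its instantiation.
    template <typename DataType, typename ALayout, typename BLayout>
    bool run_gemm(int M, int N, int K)
    {
        std::printf("GEMM %dx%dx%d, A %s-major, B %s-major, %zu-byte elements\n",
                    M, N, K,
                    std::is_same<ALayout, RowMajor>::value ? "row" : "column",
                    std::is_same<BLayout, RowMajor>::value ? "row" : "column",
                    sizeof(DataType));
        return true;
    }

    int main()
    {
        const int M = 256, N = 256, K = 64;

        // One generic body: value tags in, compile-time types out via decltype.
        auto profile = [&](auto data, auto a_layout, auto b_layout) {
            using DataType = decltype(data);
            using ALayout  = decltype(a_layout);
            using BLayout  = decltype(b_layout);
            return run_gemm<DataType, ALayout, BLayout>(M, N, K) ? 0 : 1;
        };

        // Each if/else branch in the profiler reduces to one such call.
        return profile(float{}, RowMajor{}, ColumnMajor{});
    }
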
    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t, ck::half_t, ck::half_t, float,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t, ck::half_t, ck::half_t, float,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t, ck::half_t, ck::half_t, float,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t, ck::half_t, ck::half_t, float,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<float, float, float, float,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<float, float, float, float,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<float, float, float, float,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<float, float, float, float,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t, int8_t, int8_t, int32_t,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t, int8_t, int8_t, int32_t,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t, int8_t, int8_t, int32_t,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t, int8_t, int8_t, int32_t,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? K : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float,
            ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float,
            ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
            ck::tensor_layout::gemm::RowMajor>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC, KBatch);
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Col{}, Row{});
    }
    else
    {
        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
    }
        std::cout << "this data_type & layout is not implemented" << std::endl;

    return 0;
        return 1;
    }
}

@@ -16,10 +16,6 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
    MK_NK_MN_MN_MN, // 1
    KM_KN_MN_MN_MN, // 2
    KM_NK_MN_MN_MN, // 3
    MK_KN_NM_MN_MN, // 4
    MK_NK_NM_MN_MN, // 5
    KM_KN_NM_MN_MN, // 6
    KM_NK_NM_MN_MN, // 7
};

enum struct MatrixDataType
@@ -101,17 +97,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
        const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
        const int DefaultStrideE  = ck::is_same_v<ELayout, Row> ? N : M;

        return ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType, BDataType,
            AccDataType, D0DataType, D1DataType, EDataType,
            ALayout, BLayout, D0Layout, D1Layout, ELayout>(
        bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType, BDataType,
            AccDataType, D0DataType, D1DataType, EDataType,
            ALayout, BLayout, D0Layout, D1Layout, ELayout>(
            do_verification,
            init_method,
            do_log,
@@ -124,6 +120,8 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
            (StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
            (StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
            (StrideE < 0) ? DefaultStrideE : StrideE);

        return pass ? 0 : 1;
    };

    if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
@@ -149,6 +147,6 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 0;
        return 1;
    }
}

148
profiler/src/profile_gemm_splitk.cpp
Normal file
@@ -0,0 +1,148 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "profiler/include/profile_gemm_splitk_impl.hpp"

enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

enum struct GemmDataType
{
    F32_F32_F32,    // 0
    F16_F16_F16,    // 1
    BF16_BF16_BF16, // 2
    INT8_INT8_INT8, // 3
};

int profile_gemm_splitk(int argc, char* argv[])
{
    if(argc != 15)
    {
        printf("arg1: tensor operation (gemm_splitk: Split-K GEMM)\n");
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=no, 1=yes)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        printf("arg14: split k into mulitiple batch\n");
        exit(1);
    }

    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const bool time_kernel     = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
    const int KBatch  = std::stoi(argv[14]);

    using F32 = float;
    using F16 = ck::half_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto acc_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType, BDataType, AccDataType,
                                                           CDataType, ALayout, BLayout, CLayout>(
            do_verification, init_method, do_log, time_kernel, M, N, K,
            (StrideA < 0) ? DefaultStrideA : StrideA,
            (StrideB < 0) ? DefaultStrideB : StrideB,
            (StrideC < 0) ? DefaultStrideC : StrideC,
            KBatch);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
    }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 1;
    }
}

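For context on the operation this new profiler drives: split-K GEMM partitions the reduction dimension K into KBatch slices, computes a partial product per slice (slices can run on separate workgroups), and sums the partials into C. A host-side sketch of just that decomposition; the actual device kernel typically combines the partials with atomic adds or a separate reduction pass, which this sketch does not model:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Reference split-K GEMM: C[M x N] = A[M x K] * B[K x N], row-major, beta = 0.
    // KBatch controls how many K-slices the reduction is split into.
    void gemm_splitk_ref(const std::vector<float>& A,
                         const std::vector<float>& B,
                         std::vector<float>& C,
                         int M, int N, int K, int KBatch)
    {
        assert(KBatch > 0 && K % KBatch == 0); // simplification for the sketch
        const int KPerBatch = K / KBatch;
        C.assign(static_cast<std::size_t>(M) * N, 0.0f);
        for(int kb = 0; kb < KBatch; ++kb)     // each slice is an independent GEMM
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.0f;
                    for(int k = kb * KPerBatch; k < (kb + 1) * KPerBatch; ++k)
                        partial += A[m * K + k] * B[k * N + n];
                    C[m * N + n] += partial;   // accumulate this slice's contribution
                }
    }
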
@@ -1,49 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <cstring>

#include "profiler/include/profile_convnd_fwd.hpp"

int profile_gemm(int, char*[]);
int profile_gemm_splitk(int, char*[]);
int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_gemm_bias_add_reduce(int, char*[]);
int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_convnd_fwd(int argc, char* argv[]);
int profile_convnd_bwd_data(int, char*[], int);
int profile_reduce(int, char*[]);
int profile_conv_bwd_weight(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_reduce(int, char*[]);

static void print_helper_message()
{
    // clang-format off
    printf("arg1: tensor operation (gemm: GEMM\n"
           "                        gemm_bias_2d: GEMM+Bias(2D)\n"
           "                        gemm_bias_relu: GEMM+Bias+ReLU\n"
           "                        gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
           "                        gemm_reduce: GEMM+Reduce\n"
           "                        grouped_gemm: Grouped GEMM\n"
           "                        conv_fwd: ForwardConvolution\n"
           "                        conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
           "                        conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
           "                        conv1d_bwd_data: BackwardConvolution data 1 dim\n"
           "                        conv2d_bwd_data: BackwardConvolution data 2 dim\n"
           "                        conv3d_bwd_data: BackwardConvolution data 3 dim\n"
           "                        reduce: Reduce\n"
           "                        conv2d_bwd_weight: Backward Weight Convolution 2d\n"
           "                        gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n");
    printf("arg1: tensor operation (gemm: GEMM\n"
           "                        gemm_splitk: Split-K GEMM\n"
           "                        gemm_bias_2d: GEMM+Bias(2D)\n"
           "                        gemm_bias_relu: GEMM+Bias+ReLU\n"
           "                        gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
           "                        gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n"
           "                        gemm_reduce: GEMM+Reduce\n"
           "                        batched_gemm: Batched GEMM\n"
           "                        grouped_gemm: Grouped GEMM\n"
           "                        conv_fwd: ForwardConvolution\n"
           "                        conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
           "                        conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
           "                        conv1d_bwd_data: BackwardConvolution data 1 dim\n"
           "                        conv2d_bwd_data: BackwardConvolution data 2 dim\n"
           "                        conv3d_bwd_data: BackwardConvolution data 3 dim\n"
           "                        conv2d_bwd_weight: Backward Weight Convolution 2d\n"
           "                        reduce: Reduce\n");
    // clang-format on
}

@@ -60,6 +58,10 @@ int main(int argc, char* argv[])
    {
        return profile_gemm(argc, argv);
    }
    else if(strcmp(argv[1], "gemm_splitk") == 0)
    {
        return profile_gemm_splitk(argc, argv);
    }
    else if(strcmp(argv[1], "gemm_bias_2d") == 0)
    {
        return profile_gemm_bias_2d(argc, argv);
@@ -94,7 +96,7 @@ int main(int argc, char* argv[])
    }
    else if(strcmp(argv[1], "conv_fwd") == 0)
    {
        return ck::profiler::profile_convnd_fwd(argc, argv);
        return profile_convnd_fwd(argc, argv);
    }
    else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
    {

@@ -1,109 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef BATCHED_GEMM_UTILS_HPP
#define BATCHED_GEMM_UTILS_HPP

#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"

namespace ck {
namespace batched_gemm_util {

struct GemmParams
{
    GemmParams()
        : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
    {
    }

    ck::index_t M;
    ck::index_t N;
    ck::index_t K;

    ck::index_t StrideA;
    ck::index_t StrideB;
    ck::index_t StrideC;

    float alpha;
    float beta;
};

template <typename BatchedGemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostBatchedGemm(const Tensor<ADataType>& A,
                        const Tensor<BDataType>& B,
                        Tensor<CDataType>& C,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op)
{
    auto ref_batched_gemm = BatchedGemmInstance{};
    auto ref_invoker      = ref_batched_gemm.MakeInvoker();

    auto ref_argument =
        ref_batched_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}

template <typename DeviceGemmPtr,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunDeviceBatchedGemm(DeviceGemmPtr& batched_gemm_ptr,
                          const ck::batched_gemm_util::GemmParams& params,
                          const Tensor<ADataType>& A,
                          const Tensor<BDataType>& B,
                          Tensor<CDataType>& C,
                          AElementwiseOperation a_element_op,
                          BElementwiseOperation b_element_op,
                          CElementwiseOperation c_element_op)
{
    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
    DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
    DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());

    a_g_m_k_device_buf.ToDevice(A.mData.data());
    b_g_k_n_device_buf.ToDevice(B.mData.data());

    const auto batch_count = A.mDesc.GetLengths()[0];
    auto invoker_ptr       = batched_gemm_ptr->MakeInvokerPointer();
    auto argument_ptr      = batched_gemm_ptr->MakeArgumentPointer(
        static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
        static_cast<BDataType*>(b_g_k_n_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_g_m_n_device_buf.GetDeviceBuffer()),
        params.M,
        params.N,
        params.K,
        params.StrideA,
        params.StrideB,
        params.StrideC,
        a_element_op,
        b_element_op,
        c_element_op,
        batch_count);

    if(!batched_gemm_ptr->IsSupportedArgument(argument_ptr.get()))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    invoker_ptr->Run(argument_ptr.get());
    c_g_m_n_device_buf.FromDevice(C.mData.data());
}

} // namespace batched_gemm_util
} // namespace ck
#endif

@@ -214,6 +214,11 @@ struct TestGemm
            res = ck::utils::check_err(c_device.mData, c_host.mData);
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }
        else if(std::is_same<CDataType, ck::bhalf_t>::value)
        {
            res = ck::utils::check_err(c_device.mData, c_host.mData);
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }
        else if(std::is_same<CDataType, int8_t>::value)
        {
            res = ck::utils::check_err(c_device.mData, c_host.mData);
@@ -234,121 +239,5 @@
    }
};

template <typename DeviceGemmPtr_,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct TestGemmBF16
{
    using BF16 = ck::bhalf_t;

    auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        // use fp32 host kernel to verify bf16 device kernel
        Tensor<BF16> a_m_k_bf16(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BF16> b_k_n_bf16(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<BF16> c_m_n_device_bf16(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        Tensor<float> a_m_k_fp32(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<float> b_k_n_fp32(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<float> c_m_n_host_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<float> c_m_n_device_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
        b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});

        bf16_to_f32_(a_m_k_bf16, a_m_k_fp32);
        bf16_to_f32_(b_k_n_bf16, b_k_n_fp32);

        return std::make_tuple(a_m_k_bf16,
                               b_k_n_bf16,
                               c_m_n_device_bf16,
                               a_m_k_fp32,
                               b_k_n_fp32,
                               c_m_n_host_fp32,
                               c_m_n_device_fp32);
    }

    auto operator()(DeviceGemmPtr_& gemmPtr)
    {
        // Arrange
        ck::gemm_util::GemmParams params;
        params.M       = 1024;
        params.N       = 1024;
        params.K       = 1024;
        params.StrideA = 1024;
        params.StrideB = 1024;
        params.StrideC = 1024;

        auto host_tensors            = PrepareGemmTensorBF16(params);
        const Tensor<BF16>& a_bf16   = std::get<0>(host_tensors);
        const Tensor<BF16>& b_bf16   = std::get<1>(host_tensors);
        Tensor<BF16>& c_device_bf16  = std::get<2>(host_tensors);
        Tensor<float>& a_fp32        = std::get<3>(host_tensors);
        Tensor<float>& b_fp32        = std::get<4>(host_tensors);
        Tensor<float>& c_host_fp32   = std::get<5>(host_tensors);
        Tensor<float>& c_device_fp32 = std::get<6>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        // use fp32 host kernel to verify bf16 device kernel
        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<float,
                                                      float,
                                                      float,
                                                      float,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation>;
        ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
            a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op);

        // Act
        ck::gemm_util::RunDeviceGEMM(gemmPtr,
                                     params,
                                     a_bf16,
                                     b_bf16,
                                     c_device_bf16,
                                     a_element_op,
                                     b_element_op,
                                     c_element_op);

        bf16_to_f32_(c_device_bf16, c_device_fp32);

        // Assert
        bool res = ck::utils::check_err(
            c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f);
        std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;

        return res;
    };
};

} // namespace gemm_util
} // namespace ck

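The TestGemmBF16 removed above (and the unified TestGemm that subsumes it) verifies a bf16 device kernel against an fp32 host reference, so the bf16 tensors must first be widened to float. This works because bf16 is the upper 16 bits of an IEEE-754 binary32 value; widening is a 16-bit left shift. A self-contained sketch of what a helper like bf16_to_f32_ has to do; this is illustrative, not the ck::bhalf_t implementation:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Widen one bf16 value (stored as uint16_t) to float: place the 16 stored
    // bits in the high half of a 32-bit pattern and zero-fill the low mantissa.
    float bf16_to_f32(std::uint16_t x)
    {
        std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f)); // bit-exact reinterpretation
        return f;
    }

    // Element-wise widening of a whole buffer, mirroring bf16_to_f32_(src, dst).
    std::vector<float> widen(const std::vector<std::uint16_t>& src)
    {
        std::vector<float> dst(src.size());
        for(std::size_t i = 0; i < src.size(); ++i)
            dst[i] = bf16_to_f32(src[i]);
        return dst;
    }
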
@@ -47,6 +47,11 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(

int main()
{
    using ADataType   = ck::bhalf_t;
    using BDataType   = ck::bhalf_t;
    using CDataType   = ck::bhalf_t;
    using AccDataType = float;

    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

@@ -58,13 +63,17 @@ int main()

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr, ColumnMajor, RowMajor, RowMajor,
                                           PassThrough, PassThrough, PassThrough>{}(gemmPtr);
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr, ADataType, BDataType, CDataType,
                                       AccDataType, ColumnMajor, RowMajor, RowMajor,
                                       PassThrough, PassThrough, PassThrough>{}(gemmPtr);
    }

    gemmPtrs.clear();
@@ -73,13 +82,17 @@ int main()

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr, ColumnMajor, ColumnMajor, RowMajor,
                                           PassThrough, PassThrough, PassThrough>{}(gemmPtr);
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr, ADataType, BDataType, CDataType,
                                       AccDataType, ColumnMajor, ColumnMajor, RowMajor,
                                       PassThrough, PassThrough, PassThrough>{}(gemmPtr);
    }

    gemmPtrs.clear();
@@ -88,13 +101,17 @@ int main()

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr, RowMajor, RowMajor, RowMajor,
                                           PassThrough, PassThrough, PassThrough>{}(gemmPtr);
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr, ADataType, BDataType, CDataType,
                                       AccDataType, RowMajor, RowMajor, RowMajor,
                                       PassThrough, PassThrough, PassThrough>{}(gemmPtr);
    }

    gemmPtrs.clear();
@@ -103,13 +120,17 @@ int main()

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr, RowMajor, ColumnMajor, RowMajor,
                                           PassThrough, PassThrough, PassThrough>{}(gemmPtr);
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr, ADataType, BDataType, CDataType,
                                       AccDataType, RowMajor, ColumnMajor, RowMajor,
                                       PassThrough, PassThrough, PassThrough>{}(gemmPtr);
    }

    std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;

@@ -38,10 +38,12 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

#if 0
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif

void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
@@ -69,8 +71,10 @@ int main()
    std::vector<DeviceGemmNoOpPtr> gemmPtrs;
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs);

@@ -92,8 +96,10 @@ int main()
    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs);

@@ -115,8 +121,10 @@ int main()
    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);

@@ -138,8 +146,10 @@ int main()
    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
    ck::tensor_operation::device::device_gemm_instance::

@@ -38,10 +38,12 @@ void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNo
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

#if 0
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif

void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
@@ -67,8 +69,10 @@ int main()
    std::vector<DeviceGemmNoOpPtr> gemmPtrs;
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs);

@@ -90,8 +94,10 @@ int main()
    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs);

@@ -113,8 +119,10 @@ int main()
    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);

@@ -136,8 +144,10 @@ int main()
    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
#if 0
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
#endif
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);

@@ -1,3 +1,3 @@
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
target_link_libraries(test_gemm_split_k PRIVATE host_tensor)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance)

@@ -15,7 +15,6 @@
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#include "ck/library/host_tensor/host_gemm.hpp"
@@ -28,20 +27,24 @@ enum struct GemmMatrixLayout
    KM_NK_MN, // 3
};

using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;
using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmSplitKNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
@@ -150,7 +153,7 @@ int test_gemm(const gemmArgs& args)
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    // add device GEMM instances
    std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
    std::vector<DeviceGemmSplitKNoOpPtr> gemm_ptrs;

    if(args.layout == GemmMatrixLayout::MK_KN_MN)
    {