mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add examples of batched/grouped/SplitK Gemm for int8/bfp16/fp16/fp32 (#361)
* add examples into grouped/batched_gemm * adding splitK examples * fixed splitK * add bfp16 int8 example into splitK * formatting * use static_cast * added common for batched_gemm * add commons for examples of splitK/batched/grouped_gemm * return true * adjust splitK check tol * update example Co-authored-by: Chao Liu <lc.roy86@gmail.com>
This commit is contained in:
4
example/24_batched_gemm/CMakeLists.txt
Normal file
4
example/24_batched_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp)
|
||||
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
|
||||
59
example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp
Normal file
59
example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
59
example/24_batched_gemm/batched_gemm_xdl_fp16.cpp
Normal file
59
example/24_batched_gemm/batched_gemm_xdl_fp16.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
58
example/24_batched_gemm/batched_gemm_xdl_fp32.cpp
Normal file
58
example/24_batched_gemm/batched_gemm_xdl_fp32.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F32;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
56
example/24_batched_gemm/batched_gemm_xdl_int8.cpp
Normal file
56
example/24_batched_gemm/batched_gemm_xdl_int8.cpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = int8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
194
example/24_batched_gemm/run_batched_gemm_example.inc
Normal file
194
example/24_batched_gemm/run_batched_gemm_example.inc
Normal file
@@ -0,0 +1,194 @@
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t stride_A = K;
|
||||
ck::index_t stride_B = K;
|
||||
ck::index_t stride_C = N;
|
||||
|
||||
ck::index_t batch_stride_A = M * K;
|
||||
ck::index_t batch_stride_B = K * N;
|
||||
ck::index_t batch_stride_C = M * N;
|
||||
|
||||
ck::index_t batch_count = 16;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, stride_A, stride_B, stride_C, batch_stride_A, batch_stride_B, batch_stride_C, batch_count] = problem_size;
|
||||
|
||||
// GEMM shape
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
|
||||
Tensor<BDataType> b_g_k_n(
|
||||
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
|
||||
|
||||
Tensor<EDataType> e_g_m_n_device_result(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
|
||||
std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_g_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_count,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
{},
|
||||
batch_stride_C,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
|
||||
sizeof(BDataType) * batch_count * K * N +
|
||||
sizeof(EDataType) * batch_count * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
|
||||
|
||||
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceBatchedGemm<ADataType, BDataType, EDataType, AccDataType, AElementOp, BElementOp, CDEElementOp>;
|
||||
|
||||
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
|
||||
auto ref_invoker = ref_batched_gemm.MakeInvoker();
|
||||
|
||||
Tensor<EDataType> e_g_m_n_host_result(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
|
||||
|
||||
auto ref_argument = ref_batched_gemm.MakeArgument(
|
||||
a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
pass = ck::utils::check_err(
|
||||
e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c");
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
bool run_batched_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.M = 256 * (rand() % 16 + 1);
|
||||
problem_size.N = 128 * (rand() % 16 + 1);
|
||||
problem_size.K = 64 * (rand() % 16 + 1);
|
||||
|
||||
problem_size.stride_A = problem_size.K;
|
||||
problem_size.stride_B = problem_size.K;
|
||||
problem_size.stride_C = problem_size.N;
|
||||
|
||||
problem_size.batch_stride_A = problem_size.M * problem_size.K;
|
||||
problem_size.batch_stride_B = problem_size.K * problem_size.N;
|
||||
problem_size.batch_stride_C = problem_size.M * problem_size.N;
|
||||
|
||||
problem_size.batch_count = 16;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return run_batched_gemm(problem_size, config);
|
||||
}
|
||||
Reference in New Issue
Block a user