mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
add split-k GEMM (#59)
* add DeviceGemmSplitKXdl
* add file device_gemm_splitk_xdl.hpp
* set c matrix zero
* using atomic
* add all tuning parameter to f32 mkkn
* grid size change to 720
* add tunning parameter for NT
* add tunning parameter for TN
* add tunning parameter for TT
* add m=96tunning parameter
* add lost config
* add element wise operation
* fixed MPerBlock=96
* remove marco for slpitk swtich
* add test
* add new line at the end of device_gemm_xdl_instance.hpp
* remove step hack
* seperate split-k instance files
* add tunning parameters
* change disired grid size to parameters
* remove slice length
* add desiredgridsize parameter to ckProfiler
* add losting file device_gemm_xdl_splitk_instance.hpp
* change desired gride size to kbatch
* format
* format
* clean up
* add selection of device_instances
* clean code
* fix build issue
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
[ROCm/composable_kernel commit: 4be7f0198e]
This commit is contained in:
@@ -62,7 +62,10 @@ template <typename GridwiseGemm,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -74,7 +77,10 @@ __global__ void
|
||||
const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const void CONSTANT* p_c_block_cluster_adaptor)
|
||||
const void CONSTANT* p_a_element_op,
|
||||
const void CONSTANT* p_b_element_op,
|
||||
const void CONSTANT* p_c_element_op,
|
||||
const void CONSTANT* p_block_2_ctile_map)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
@@ -86,8 +92,14 @@ __global__ void
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
|
||||
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
|
||||
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
|
||||
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
|
||||
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
|
||||
cast_pointer_to_generic_address_space(p_a_element_op));
|
||||
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
|
||||
cast_pointer_to_generic_address_space(p_b_element_op));
|
||||
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
|
||||
cast_pointer_to_generic_address_space(p_c_element_op));
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
@@ -98,7 +110,10 @@ __global__ void
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -110,6 +125,9 @@ template <index_t BlockSize,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
@@ -118,7 +136,6 @@ template <index_t BlockSize,
|
||||
index_t K1Value,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
@@ -126,7 +143,7 @@ template <index_t BlockSize,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
@@ -134,12 +151,10 @@ template <index_t BlockSize,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool BBlockLdsExtraN,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsExtraM,
|
||||
bool BBlockLdsExtraN>
|
||||
index_t CThreadTransferDstScalarPerVector>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -358,6 +373,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
|
||||
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
@@ -456,7 +474,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
@@ -487,7 +504,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
@@ -583,8 +599,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += K0PerBlock;
|
||||
} while(k_block_data_begin < (K0 - K0PerBlock));
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (K0 - K0PerBlock));
|
||||
}
|
||||
|
||||
// tail
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn =
|
||||
using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Col, Row, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn =
|
||||
using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Col, Col, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn =
|
||||
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Row, Row, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn =
|
||||
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -44,21 +44,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Row, Col, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn =
|
||||
using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Col, Row, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn =
|
||||
using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Col, Col, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn =
|
||||
using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Row, Row, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -1,8 +1,8 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn =
|
||||
using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
@@ -44,21 +44,10 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Row, Col, Row>(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
@@ -0,0 +1,51 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,51 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,51 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,56 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
|
||||
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -13,19 +13,19 @@ template <typename AElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
#ifndef DEVICE_GEMM_INSTANTCE_HPP
|
||||
#define DEVICE_GEMM_INSTANTCE_HPP
|
||||
|
||||
#include "device_gemm.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void add_device_gemm_instance(
|
||||
std::vector<DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -408,7 +408,8 @@ struct DeviceGemmXdl
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
|
||||
606
device_operation/include/device_gemm_xdl_splitk.hpp
Normal file
606
device_operation/include/device_gemm_xdl_splitk.hpp
Normal file
@@ -0,0 +1,606 @@
|
||||
#ifndef DEVICE_GEMM_SPLITK_XDL_HPP
|
||||
#define DEVICE_GEMM_SPLITK_XDL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
#ifndef CK_RUN_KERNEL_AND_TIME
|
||||
#define CK_RUN_KERNEL_AND_TIME 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceGemmXdlSplitK
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static auto
|
||||
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
|
||||
{
|
||||
assert(KPad % (K1 * KBatch) == 0);
|
||||
|
||||
const index_t K0 = KPad / (K1 * KBatch);
|
||||
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto a_grid_desc_kbatch_k0_m_k1 = transform_tensor_descriptor(
|
||||
a_grid_desc_m_kpad,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return a_grid_desc_kbatch_k0_m_k1;
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
|
||||
{
|
||||
assert(KPad % (K1 * KBatch) == 0);
|
||||
|
||||
const index_t K0 = KPad / (K1 * KBatch);
|
||||
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_kbatch_k0_n_k1 = transform_tensor_descriptor(
|
||||
b_grid_desc_kpad_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return b_grid_desc_kbatch_k0_n_k1;
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
}
|
||||
}
|
||||
|
||||
static auto GetKPad(index_t K, index_t KBatch)
|
||||
{
|
||||
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
|
||||
const index_t KPad = KBatch * K0 * K1;
|
||||
return KPad;
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap =
|
||||
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t k_batch)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
k_batch_{k_batch}
|
||||
{
|
||||
int KPad = DeviceGemmXdlSplitK::GetKPad(K, k_batch_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmXdlSplitK::MakeAGridDescriptor_KBatch_K0_M_K1(
|
||||
M, K, StrideA, k_batch_, KPad);
|
||||
b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmXdlSplitK::MakeBGridDescriptor_KBatch_K0_N_K1(
|
||||
K, N, StrideB, k_batch_, KPad);
|
||||
c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
M01_,
|
||||
N01_))
|
||||
{
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
|
||||
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceGemmXdlSplitK::Argument;
|
||||
|
||||
void ShowInfo(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(nrepeat > 0)
|
||||
{
|
||||
ShowInfo(arg);
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
if(kbatch > 1 || nrepeat <= 0)
|
||||
{
|
||||
hipGetErrorString(
|
||||
hipMemset(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
};
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitK::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlSplitK"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -14,14 +14,18 @@ include_directories(BEFORE
|
||||
|
||||
# device_gemm_instance
|
||||
set(DEVICE_GEMM_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
|
||||
)
|
||||
|
||||
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
|
||||
@@ -83,7 +87,8 @@ set(PROFILER_SOURCE
|
||||
profile_conv_fwd.cpp
|
||||
profile_conv_fwd_bias_relu.cpp
|
||||
profile_conv_fwd_bias_relu_add.cpp
|
||||
profile_conv_fwd_bias_relu_atomic_add.cpp)
|
||||
profile_conv_fwd_bias_relu_atomic_add.cpp
|
||||
)
|
||||
add_executable(ckProfiler ${PROFILER_SOURCE})
|
||||
|
||||
target_link_libraries(ckProfiler PRIVATE host_tensor)
|
||||
|
||||
@@ -1,78 +1,29 @@
|
||||
#pragma once
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
using DeviceGemmNoOpPtr =
|
||||
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
@@ -97,7 +48,8 @@ void profile_gemm_impl(int do_verification,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC)
|
||||
int StrideC,
|
||||
int KBatch = 1)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -122,17 +74,20 @@ void profile_gemm_impl(int do_verification,
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
// set zero to c_device_buf
|
||||
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -155,9 +110,103 @@ void profile_gemm_impl(int do_verification,
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
gemm_ptrs);
|
||||
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
|
||||
is_same<CDataType, float>::value)
|
||||
{
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
if(KBatch > 1)
|
||||
{
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
{
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
if(gemm_ptrs.size() <= 0)
|
||||
{
|
||||
@@ -184,7 +233,8 @@ void profile_gemm_impl(int do_verification,
|
||||
StrideC,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
KBatch);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
|
||||
@@ -35,19 +35,20 @@ enum GemmDataType
|
||||
|
||||
int profile_gemm(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 14)
|
||||
if(!(argc == 14 || argc == 15))
|
||||
{
|
||||
printf("arg1: tensor operation (gemm: GEMM)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, n] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, n] * B[n, k] = C[m, n])\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg8: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: run kernel # of times (>1)\n");
|
||||
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
|
||||
printf("arg14: split k into mulitiple batch\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -65,6 +66,9 @@ int profile_gemm(int argc, char* argv[])
|
||||
const int StrideA = std::stoi(argv[11]);
|
||||
const int StrideB = std::stoi(argv[12]);
|
||||
const int StrideC = std::stoi(argv[13]);
|
||||
int KBatch = 1;
|
||||
if(argc == 15)
|
||||
KBatch = std::stoi(argv[14]);
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -11,8 +11,13 @@ include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
)
|
||||
|
||||
# test_magic_number_division
|
||||
set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp)
|
||||
|
||||
add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISISON_SOURCE})
|
||||
|
||||
target_link_libraries(test_magic_number_division PRIVATE host_tensor)
|
||||
|
||||
# test_split_k
|
||||
set(SPLIT_K_SOURCE split_k/main.cpp)
|
||||
add_executable(test_split_k ${SPLIT_K_SOURCE})
|
||||
target_link_libraries(test_split_k PRIVATE host_tensor)
|
||||
target_link_libraries(test_split_k PRIVATE device_gemm_instance)
|
||||
|
||||
218
test/split_k/main.cpp
Normal file
218
test/split_k/main.cpp
Normal file
@@ -0,0 +1,218 @@
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
|
||||
enum GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
};
|
||||
|
||||
using DeviceGemmNoOpPtr =
|
||||
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
template <typename T>
|
||||
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float max_diff = 1e-6;
|
||||
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 9)
|
||||
{
|
||||
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
const int K = std::stoi(argv[4]);
|
||||
|
||||
const int StrideA = std::stoi(argv[5]);
|
||||
const int StrideB = std::stoi(argv[6]);
|
||||
const int StrideC = std::stoi(argv[7]);
|
||||
const int KBatch = std::stoi(argv[8]);
|
||||
|
||||
bool a_row_major, b_row_major, c_row_major;
|
||||
|
||||
switch(layout)
|
||||
{
|
||||
case GemmMatrixLayout::MK_KN_MN:
|
||||
a_row_major = true;
|
||||
b_row_major = true;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::MK_NK_MN:
|
||||
a_row_major = true;
|
||||
b_row_major = false;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::KM_KN_MN:
|
||||
a_row_major = false;
|
||||
b_row_major = true;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::KM_NK_MN:
|
||||
a_row_major = false;
|
||||
b_row_major = false;
|
||||
c_row_major = true;
|
||||
break;
|
||||
default: printf("not supported layout"); return 1;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
|
||||
if(row_major)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, a_row_major));
|
||||
Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, b_row_major));
|
||||
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major));
|
||||
Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major));
|
||||
|
||||
// init data
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
|
||||
// set zero to c_device_buf
|
||||
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);
|
||||
|
||||
host_gemm_mk_kn_mn(a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
|
||||
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
for(auto& gemm_ptr : gemm_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
KBatch);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), 0);
|
||||
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
if(!check_out(c_m_n_host_result, c_m_n_device_result))
|
||||
{
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
success = true;
|
||||
}
|
||||
}
|
||||
|
||||
if(success)
|
||||
{
|
||||
std::cout << "test split k : Pass" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test split k: Fail " << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user