mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add ckProfiler
This commit is contained in:
@@ -23,6 +23,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
static constexpr bool PermuteB = true;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 128;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
@@ -30,35 +32,31 @@ using DeviceGemmV2Instance =
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
#if 0
|
||||
64,
|
||||
16, 16,
|
||||
256, 8, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 1,
|
||||
1, 1, S<1, 16, 1, 4>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, CDataType, CDataType, false, PermuteB>;
|
||||
|
||||
[[maybe_unused]] static int KPerBlock = 256;
|
||||
#else
|
||||
128,
|
||||
16, 32,
|
||||
128, 8, 32,
|
||||
16, 128,
|
||||
KPerBlock, 8, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
1, 4,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 8>, 4,
|
||||
#else
|
||||
128,
|
||||
16, 64,
|
||||
KPerBlock, 8, 32,
|
||||
16, 16,
|
||||
1, 2,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 8>, 4,
|
||||
#endif
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>;
|
||||
|
||||
[[maybe_unused]]static int KPerBlock = 128;
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
|
||||
@@ -36,6 +36,9 @@ struct DeviceGemmV2 : public BaseOperator
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
|
||||
@@ -637,6 +637,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
index_t GetKPerBlock() override { return KPerBlock; }
|
||||
|
||||
bool GetPermuteB() override { return PermuteB; }
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
|
||||
@@ -1175,6 +1175,34 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, f8_t>::value)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
constexpr index_t pack_size = 8;
|
||||
|
||||
static_assert(SrcScalarPerVector % pack_size == 0, "");
|
||||
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack8{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
|
||||
@@ -12,7 +12,7 @@ using half_t = _Float16;
|
||||
using int4_t = _BitInt(4);
|
||||
using f8_t = _BitInt(8);
|
||||
using bf8_t = unsigned _BitInt(8);
|
||||
using pk_i4_t = unsigned char;
|
||||
using pk_i4_t = uint8_t;
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
|
||||
@@ -22,6 +22,7 @@ using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
|
||||
@@ -166,11 +166,17 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -810,6 +816,17 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -97,6 +97,8 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
|
||||
device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I4 = pk_i4_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
#if 0
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Compute friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -65,6 +65,7 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -86,9 +87,13 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
break;
|
||||
default:
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -100,11 +105,10 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmV2<ALayout,
|
||||
BLayout,
|
||||
@@ -152,6 +156,87 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
// profile device GEMM instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
const int KPerBlock = op_ptr->GetKPerBlock();
|
||||
|
||||
if(op_ptr->GetPermuteB())
|
||||
{
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int i4x2 = b_k_n_permute(j + k * 2, i);
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b_device_buf.ToDevice(b_k_n_permute.mData.data());
|
||||
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
|
||||
|
||||
if(KBatch > 0)
|
||||
@@ -238,7 +323,15 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N / BPackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -1,87 +1,87 @@
|
||||
# ckProfiler
|
||||
set(PROFILER_SOURCES
|
||||
profiler.cpp
|
||||
profile_gemm.cpp
|
||||
profile_reduce.cpp
|
||||
profile_groupnorm_bwd_data.cpp
|
||||
profile_groupnorm_fwd.cpp
|
||||
profile_layernorm_bwd_data.cpp
|
||||
profile_layernorm_bwd_gamma_beta.cpp
|
||||
profile_groupnorm_bwd_gamma_beta.cpp
|
||||
profile_layernorm_fwd.cpp
|
||||
profile_max_pool2d_fwd.cpp
|
||||
profile_pool3d_fwd.cpp
|
||||
profile_avg_pool3d_bwd.cpp
|
||||
profile_max_pool3d_bwd.cpp
|
||||
profile_avg_pool2d_bwd.cpp
|
||||
profile_max_pool2d_bwd.cpp
|
||||
profile_softmax.cpp
|
||||
profile_batchnorm_fwd.cpp
|
||||
profile_batchnorm_bwd.cpp
|
||||
profile_batchnorm_infer.cpp
|
||||
profile_conv_tensor_rearrange.cpp
|
||||
profile_transpose.cpp
|
||||
profile_permute_scale.cpp
|
||||
#profile_gemm.cpp
|
||||
#profile_reduce.cpp
|
||||
#profile_groupnorm_bwd_data.cpp
|
||||
#profile_groupnorm_fwd.cpp
|
||||
#profile_layernorm_bwd_data.cpp
|
||||
#profile_layernorm_bwd_gamma_beta.cpp
|
||||
#profile_groupnorm_bwd_gamma_beta.cpp
|
||||
#profile_layernorm_fwd.cpp
|
||||
#profile_max_pool2d_fwd.cpp
|
||||
#profile_pool3d_fwd.cpp
|
||||
#profile_avg_pool3d_bwd.cpp
|
||||
#profile_max_pool3d_bwd.cpp
|
||||
#profile_avg_pool2d_bwd.cpp
|
||||
#profile_max_pool2d_bwd.cpp
|
||||
#profile_softmax.cpp
|
||||
#profile_batchnorm_fwd.cpp
|
||||
#profile_batchnorm_bwd.cpp
|
||||
#profile_batchnorm_infer.cpp
|
||||
#profile_conv_tensor_rearrange.cpp
|
||||
#profile_transpose.cpp
|
||||
#profile_permute_scale.cpp
|
||||
)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp)
|
||||
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
#if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
# list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
|
||||
# endif()
|
||||
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
# endif()
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
|
||||
# endif()
|
||||
# list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp)
|
||||
#
|
||||
#endif()
|
||||
#
|
||||
#if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
|
||||
# endif()
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
|
||||
#endif()
|
||||
#
|
||||
#if(DL_KERNELS)
|
||||
# list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
|
||||
#endif()
|
||||
|
||||
set(PROFILER_EXECUTABLE ckProfiler)
|
||||
|
||||
@@ -94,87 +94,87 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
endif()
|
||||
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
|
||||
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
|
||||
#
|
||||
#if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
|
||||
# endif()
|
||||
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance)
|
||||
# endif()
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
|
||||
# endif()
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
#endif()
|
||||
#
|
||||
#if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
||||
# endif()
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
|
||||
#endif()
|
||||
#
|
||||
#if(DL_KERNELS)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
|
||||
#endif()
|
||||
|
||||
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
|
||||
|
||||
@@ -27,6 +27,7 @@ enum struct GemmDataType
|
||||
F16_F8_F16, // 5
|
||||
F16_F16_F16_F8, // 6
|
||||
F8_F8_BF16, // 7
|
||||
F16_I4_F16, // 8
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_universal"
|
||||
@@ -39,7 +40,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
|
||||
"f16->f8; 7: f8->bf16, "
|
||||
"comp f8)\n");
|
||||
"comp f8; 8: f16@i4)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
@@ -87,6 +88,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -187,6 +189,10 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user