Add f4 ckProfiler

This commit is contained in:
Ding, Yi
2025-05-16 06:26:18 +00:00
parent b8fcc7e3b1
commit dc30e7d025
15 changed files with 795 additions and 427 deletions

View File

@@ -500,7 +500,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
std::size_t num_btype = sizeof(ADataType) * M * K / pack_size_v<ADataType> +
sizeof(BDataType) * K * N / pack_size_v<BDataType> +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;

View File

@@ -497,33 +497,36 @@ struct scalar_type<bool>
};
template <typename T>
struct element_type
struct pack_info
{
private:
static constexpr auto get_element_type()
static constexpr auto get_pack_info()
{
using U = remove_cvref_t<T>;
if constexpr(std::is_same_v<U, pk_i4_t>)
return int4_t{};
return ck::Tuple<ck::Number<2>, int4_t>{};
else if constexpr(std::is_same_v<U, f4x2_pk_t>)
return f4_t{};
return ck::Tuple<ck::Number<2>, f4_t>{};
else if constexpr(std::is_same_v<U, f6x16_pk_t>)
return f6_t{};
return ck::Tuple<ck::Number<16>, f6_t>{};
else if constexpr(std::is_same_v<U, bf6x16_pk_t>)
return bf6_t{};
return ck::Tuple<ck::Number<16>, bf6_t>{};
else if constexpr(std::is_same_v<U, f6x32_pk_t>)
return f6_t{};
return ck::Tuple<ck::Number<32>, f6_t>{};
else if constexpr(std::is_same_v<U, bf6x32_pk_t>)
return bf6_t{};
return ck::Tuple<ck::Number<32>, bf6_t>{};
else
return T{};
return ck::Tuple<ck::Number<1>, T>{};
}
public:
using type = decltype(get_element_type());
using element_type = remove_cvref_t<decltype(get_pack_info().At(ck::Number<1>{}))>;
static constexpr auto pack_size = static_cast<index_t>(get_pack_info().At(ck::Number<0>{}));
};
template <typename T>
using element_type_t = typename element_type<T>::type;
using element_type_t = typename pack_info<T>::element_type;
template <typename T>
inline constexpr index_t pack_size_v = pack_info<T>::pack_size;
#if defined(_WIN32)
using int64_t = long long;

View File

@@ -23,6 +23,7 @@ using I32 = int32_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using I4 = ck::pk_i4_t;
using F4 = ck::f4x2_pk_t;
using Empty_Tuple = ck::Tuple<>;

View File

@@ -45,6 +45,20 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Col,
Row,
F4,
I32,
F4,
I32,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Row,
@@ -127,6 +141,11 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
}
else if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
is_same_v<CDataType, F16>)
{
add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)

View File

@@ -6,6 +6,7 @@ list(APPEND GEMM_MX_INSTANCES
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp
device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp
device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp
)
@@ -13,6 +14,7 @@ set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES})

View File

@@ -40,6 +40,7 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple<
#if 0 // TODO: Fix v1
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
@@ -51,6 +52,7 @@ using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple<
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F4 = f4x2_pk_t;
using F16 = half_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#############################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 256, 32, 32, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Col,
Row,
F4,
E8M0PK,
F4,
E8M0PK,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -39,6 +39,7 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple<
#if 0 // TODO: Fix v1
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
@@ -52,6 +53,7 @@ using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple<
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance

View File

@@ -39,6 +39,7 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple<
#if 0 // TODO: Fix v1
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
@@ -50,6 +51,7 @@ using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple<
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance

View File

@@ -39,6 +39,7 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple<
#if 0 // TODO: Fix v1
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
@@ -50,6 +51,7 @@ using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple<
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance

View File

@@ -0,0 +1,498 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
namespace profiler {
#if 1
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
#endif
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
int ScaleBlockSize>
bool profile_gemm_mx_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
using XDataType = e8m0_bexp_t;
using AScaleLayout = Row;
using BScaleLayout = Col;
using XPackedDataType = // TODO: use int32 for all
conditional_t<is_same<ADataType, ck::f4x2_pk_t>::value, int32_t, e8m0_bexp_t>;
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
return HostTensorDescriptor({row, col}, {stride, 1});
else
return HostTensorDescriptor({row, col}, {1, stride});
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
return static_cast<ck::index_t>(col);
else
return static_cast<ck::index_t>(row);
}
else
return static_cast<ck::index_t>(stride);
};
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
// scales for A and B
Tensor<XDataType> a_m_k_scale(
f_host_tensor_descriptor(M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_k_n_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
// shuffled scales for A and B
Tensor<XDataType> a_shuffled_scale(
f_host_tensor_descriptor(M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_shuffled_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::size_t total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else
return ck::type_convert<BDataType>(x);
};
switch(init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(0.5f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
if(do_log)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5}); // Z[-4,4]
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5}); // Z[-4,4]
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1]
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
break;
}
#if 1
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(
a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<BLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
#endif
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
if(do_log > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize());
if(do_log > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
if(do_log > 0)
std::cout << "Done." << std::endl;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
XPackedDataType,
BDataType,
XPackedDataType,
CDataType,
ScaleBlockSize,
AElementOp,
BElementOp,
CElementOp>;
std::cout << "finding op instances..." << std::endl;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm< //
ADataType,
BDataType,
CDataType,
float, // AccDataType
XDataType,
AElementOp,
BElementOp,
CElementOp,
float, // ComputeTypeA
float // ComputeTypeB
>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
b_k_n_scale,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
bool pass = true;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideC,
kbatch_curr,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_log)
{
if(init_method == 0)
{
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(0, 12));
pass = pass & (std::abs(expected - computed) <= 0.0f);
std::cout << "\nExpected vs Computed: " << expected << " vs "
<< computed << ((pass) ? " (PASSED!)" : " (FAILED!)")
<< std::endl
<< std::endl;
}
else
{
if constexpr(is_same_v<ADataType, ck::f8_t> ||
is_same_v<ADataType, ck::bf8_t>)
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",")
<< "\n";
else
std::cout << "A: WIP PRINT PACKED TYPE\n";
LogRangeAsType<float>(std::cout << "a_scale : ", a_m_k_scale.mData, ",")
<< "\n";
if constexpr(is_same_v<BDataType, ck::f8_t> ||
is_same_v<BDataType, ck::bf8_t>)
LogRangeAsType<float>(std::cout << "b : ", b_k_n.mData, ",")
<< "\n";
else
std::cout << "B: WIP PRINT PACKED TYPE\n";
LogRangeAsType<float>(std::cout << "b_scale: ", b_k_n_scale.mData, ",")
<< "\n";
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< "\n";
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) +
// scaling of partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop =
std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
// TODO: fp6?
std::size_t num_btype = sizeof(ADataType) * M * K / pack_size_v<ADataType> +
sizeof(BDataType) * K * N / pack_size_v<BDataType> +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -7,6 +7,8 @@ endif()
if (CK_PROFILER_INSTANCE_FILTER STREQUAL "")
set(CK_PROFILER_INSTANCE_FILTER ${CK_PROFILER_OP_FILTER})
endif()
message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
set(PROFILER_OPS
profile_gemm.cpp
@@ -60,6 +62,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
@@ -111,7 +116,7 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
# flags to compress the library
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
message("Adding --offload-compress flag for ${PROFILER_EXECUTABLE}")
message(STATUS "Adding --offload-compress flag for ${PROFILER_EXECUTABLE}")
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE --offload-compress)
endif()
@@ -163,6 +168,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
endif()
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)

View File

@@ -0,0 +1,137 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_mx_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F4_F4_F16, // 0
};
#define OP_NAME "gemm_mx"
#define OP_DESC "GEMM_mx"
int profile_gemm_mx(int argc, char* argv[])
{
if(argc != 11 && argc != 14 && argc != 18)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: f4->f16)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("optional:\n");
printf("arg14: number of kbatch (default 1)\n");
printf("arg15: number of warm-up cycles (default 1)\n");
printf("arg16: number of iterations (default 10)\n");
printf("arg17: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
int arg_index = 2;
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[arg_index++]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[arg_index++]));
const bool do_verification = std::stoi(argv[arg_index++]);
const int init_method = std::stoi(argv[arg_index++]);
const bool do_log = std::stoi(argv[arg_index++]);
const bool time_kernel = std::stoi(argv[arg_index++]);
const int M = std::stoi(argv[arg_index++]);
const int N = std::stoi(argv[arg_index++]);
const int K = std::stoi(argv[arg_index++]);
int StrideA = -1, StrideB = -1, StrideC = -1;
if(argc > arg_index)
{
StrideA = std::stoi(argv[arg_index++]);
StrideB = std::stoi(argv[arg_index++]);
StrideC = std::stoi(argv[arg_index++]);
}
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
int KBatch = 1;
if(argc > arg_index)
{
KBatch = std::stoi(argv[arg_index++]);
n_warmup = std::stoi(argv[arg_index++]);
n_iter = std::stoi(argv[arg_index++]);
rotating = std::stoull(argv[arg_index++]) * 1024 * 1024;
}
using F16 = ck::half_t;
using F4 = ck::f4x2_pk_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::
profile_gemm_mx_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout, 32>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch,
n_warmup,
n_iter,
rotating);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F4{}, F4{}, F16{}, Row{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_mx);

View File

@@ -18,6 +18,7 @@
#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "profiler/profile_gemm_mx_impl.hpp"
namespace ck {
namespace test {
@@ -27,401 +28,6 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
} // namespace
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
int ScaleBlockSize>
bool profile_gemm_mx_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
using ScaleDataType = e8m0_bexp_t;
using AScaleLayout = Row;
using BScaleLayout = Col;
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<ck::index_t>(col);
}
else
{
return static_cast<ck::index_t>(row);
}
}
else
return static_cast<ck::index_t>(stride);
};
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<ScaleDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<ScaleDataType> b_k_n_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::size_t total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
switch(init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<ScaleDataType>{ck::type_convert<ScaleDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
ck::utils::FillConstant<ScaleDataType>{ck::type_convert<ScaleDataType>(1.0f)}(b_k_n_scale);
if(do_log)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5}); // Z[-4,4]
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5}); // Z[-4,4]
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<ScaleDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<ScaleDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_3<ScaleDataType>{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1]
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_3<ScaleDataType>{powf(2.0f, -125.0f), 1.0f});
break;
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
if(do_log > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(ScaleDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(ScaleDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
if(do_log > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
if(do_log > 0)
std::cout << "Done." << std::endl;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
ScaleDataType,
BDataType,
ScaleDataType,
CDataType,
ScaleBlockSize,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMXGemm<ADataType,
BDataType,
CDataType,
float, // AccDataType
ScaleDataType,
AElementOp,
BElementOp,
CElementOp,
float, // ComputeTypeA
float // ComputeTypeB
>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
b_k_n_scale,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<ScaleDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<ScaleDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideC,
kbatch_curr,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_log)
{
if(init_method == 0)
{
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(0, 12));
pass = pass & (std::abs(expected - computed) <= 0.0f);
std::cout << "\nExpected vs Computed: " << expected << " vs "
<< computed << ((pass) ? " (PASSED!)" : " (FAILED!)")
<< std::endl
<< std::endl;
}
else
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "a_scale : ", a_m_k_scale.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b_scale: ", b_k_n_scale.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) +
// scaling of partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop =
std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(ScaleDataType) * (M * K + K * N) / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass;
}
template <typename Tuple>
class TestGemmMX : public testing::Test
{
@@ -471,25 +77,25 @@ class TestGemmMX : public testing::Test
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::test::profile_gemm_mx_impl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleBlockSize>(verify_,
init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
kbatch,
n_warmup,
n_iter);
bool pass = ck::profiler::profile_gemm_mx_impl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleBlockSize>(verify_,
init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass);
}
};