mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
MX GEMM - Add FP6 GEMM Test (#2488)
* Add F6 GEMM MX Test
* Add BF6 GEMM MX Test
[ROCm/composable_kernel commit: 25b359d630]
This commit is contained in:
committed by
GitHub
parent
a024e11036
commit
dfb07b8240
@@ -24,6 +24,8 @@ using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F6 = ck::f6x16_pk_t;
|
||||
using BF6 = ck::bf6x16_pk_t;
|
||||
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
using E8M0PK = int32_t;
|
||||
|
||||
@@ -87,6 +87,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
F6,
|
||||
E8M0PK,
|
||||
F6,
|
||||
E8M0PK,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
BF6,
|
||||
E8M0PK,
|
||||
BF6,
|
||||
E8M0PK,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
@@ -130,6 +158,8 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
// Row-Col-Row -- one of the two currently supported layouts, another one is
|
||||
// Row-MFMA-Row
|
||||
if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
@@ -147,6 +177,16 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, F6> && is_same_v<BDataType, F6> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, BF6> && is_same_v<BDataType, BF6> &&
|
||||
is_same_v<CDataType, BF16>)
|
||||
{
|
||||
add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
set(GEMM_MX_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_MX_INSTANCES
|
||||
device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp
|
||||
@@ -11,6 +13,8 @@ list(APPEND GEMM_MX_INSTANCES
|
||||
)
|
||||
|
||||
|
||||
set_source_files_properties(device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
using E8M0PK = int32_t;
|
||||
using BF6 = ck::bf6x16_pk_t;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto ScaleBlockSize = 32;
|
||||
static constexpr auto KPerBlock = 256 / ck::packed_size_v<BF6>; // 256 bf6 = 16 bf6x16_pk_t
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
std::nullptr_t
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
BF6,
|
||||
E8M0PK,
|
||||
BF6,
|
||||
E8M0PK,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
using E8M0PK = int32_t;
|
||||
using F6 = ck::f6x16_pk_t;
|
||||
using BF6 = ck::bf6x16_pk_t;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto ScaleBlockSize = 32;
|
||||
static constexpr auto KPerBlock = 256 / ck::packed_size_v<F6>; // 256 f6 = 16 f6x16_pk_t
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
std::nullptr_t
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
F6,
|
||||
E8M0PK,
|
||||
F6,
|
||||
E8M0PK,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -216,12 +216,20 @@ bool profile_gemm_mx_impl(int do_verification,
|
||||
auto a_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<ADataType>(ck::float2_t(x));
|
||||
else if constexpr(ck::packed_size_v<ADataType> == 32)
|
||||
return ck::type_convert<ADataType>(ck::float32_t(x));
|
||||
else if constexpr(ck::packed_size_v<ADataType> == 16)
|
||||
return ck::type_convert<ADataType>(ck::float16_t(x));
|
||||
else
|
||||
return ck::type_convert<ADataType>(x);
|
||||
};
|
||||
auto b_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<BDataType>(ck::float2_t(x));
|
||||
else if constexpr(ck::packed_size_v<BDataType> == 32)
|
||||
return ck::type_convert<BDataType>(ck::float32_t(x));
|
||||
else if constexpr(ck::packed_size_v<BDataType> == 16)
|
||||
return ck::type_convert<BDataType>(ck::float16_t(x));
|
||||
else
|
||||
return ck::type_convert<BDataType>(x);
|
||||
};
|
||||
@@ -247,15 +255,17 @@ bool profile_gemm_mx_impl(int do_verification,
|
||||
|
||||
case 1:
|
||||
|
||||
a_m_k.GenerateTensorDistr(int_distr{-4, 5}); // Z[-4,4]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-4, 5}); // Z[-4,4]
|
||||
a_m_k.GenerateTensorDistr(
|
||||
int_distr{-4, 4}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-4,4]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-4, 4}); // Z[-4,4]
|
||||
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
break;
|
||||
|
||||
default:
|
||||
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
a_m_k.GenerateTensorDistr(
|
||||
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr)));
|
||||
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
add_gtest_executable(test_gemm_mx test_gemm_mx.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_compile_options(test_gemm_mx PRIVATE -mavx512f)
|
||||
target_link_libraries(test_gemm_mx PRIVATE utility device_gemm_mx_instance)
|
||||
endif()
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using F6 = ck::f6_t;
|
||||
using BF6 = ck::bf6_t;
|
||||
using F6 = ck::f6x16_pk_t;
|
||||
using BF6 = ck::bf6x16_pk_t;
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
@@ -58,7 +58,9 @@ using KernelTypes_MK_NK = ::testing::Types<
|
||||
std::tuple< F8, F8, F16, ck::Number<32> >,
|
||||
std::tuple< F8, F8, BF16, ck::Number<32> >,
|
||||
#endif
|
||||
std::tuple< F4, F4, F16, ck::Number<32> >
|
||||
std::tuple< F4, F4, F16, ck::Number<32> >,
|
||||
std::tuple< F6, F6, F16, ck::Number<32> >,
|
||||
std::tuple< BF6, BF6, BF16, ck::Number<32> >
|
||||
>;
|
||||
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestGemmMX : public testing::Test
|
||||
const int StrideB,
|
||||
const int StrideC,
|
||||
int kbatch = 1,
|
||||
int n_warmup = 1,
|
||||
int n_warmup = 10,
|
||||
int n_iter = 10)
|
||||
{
|
||||
bool pass = ck::profiler::profile_gemm_mx_impl<ADataType,
|
||||
|
||||
Reference in New Issue
Block a user