mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)
* Add gemm_mx_fp8_bf8 example with row-major B
* Add more overloads of MX MFMA instructions
* Add MK_KN (RRR) tests
* Add KM_NK (CCR) tests
* Add more problem sizes to Large tests
* Add test_gemm_mx to the list of regression tests
[ROCm/composable_kernel commit: 79b0bfeb41]
This commit is contained in:
committed by
GitHub
parent
e6e687d8a7
commit
e70035cf5a
@@ -6,3 +6,6 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
|
||||
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
|
||||
97
example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp
Normal file
97
example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::bf8_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
256, // MPerBlock
|
||||
256, // NPerBlock
|
||||
128, // KPerBlock
|
||||
16, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -797,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
|
||||
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
@@ -929,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
|
||||
@@ -1129,6 +1129,12 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
@@ -1147,6 +1153,18 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
|
||||
@@ -532,7 +532,44 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
// XXX: Note on the scale_a and scale_b parameters:
|
||||
// If compiler detects that one or both scales are constant values, it will treat that
|
||||
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
||||
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
||||
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
||||
|
||||
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
||||
// when OPSEL is set otherwise.
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -576,7 +613,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -605,7 +642,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -617,6 +654,64 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const bf8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -45,6 +45,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Row,
|
||||
Row,
|
||||
BF8,
|
||||
e8m0_bexp_t,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Col,
|
||||
Col,
|
||||
Row,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
@@ -93,11 +121,31 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, BF16>)
|
||||
{
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, BF8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
|
||||
add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, BF16>)
|
||||
{
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
# ONLY MX_KERNELS
|
||||
set(GEMM_MX_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_MX_INSTANCES
|
||||
list(APPEND GEMM_MX_INSTANCES
|
||||
device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp
|
||||
device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp
|
||||
)
|
||||
|
||||
|
||||
set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
|
||||
add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = f8_t;
|
||||
using BF8 = bf8_t;
|
||||
using F16 = half_t;
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto ScaleBlockSize = 32;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT)
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Row,
|
||||
Row,
|
||||
BF8,
|
||||
E8M0,
|
||||
F8,
|
||||
E8M0,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = f8_t;
|
||||
using F16 = half_t;
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto ScaleBlockSize = 32;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT)
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Col,
|
||||
Col,
|
||||
Row,
|
||||
F8,
|
||||
E8M0,
|
||||
F8,
|
||||
E8M0,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -36,6 +36,7 @@ set(REGRESSION_TESTS
|
||||
test_batchnorm_bwd_rank_4
|
||||
test_grouped_convnd_bwd_data_xdl
|
||||
test_conv_tensor_rearrange
|
||||
test_gemm_mx
|
||||
)
|
||||
|
||||
function(add_test_executable TEST_NAME)
|
||||
|
||||
@@ -39,17 +39,49 @@ class TestGemmMX_MK_NK
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMX_MK_KN
|
||||
: public ck::test::TestGemmMX<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMX_KM_NK
|
||||
: public ck::test::TestGemmMX<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
using KernelTypes_F8_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
// ADataType, BDataType, CDataType, ScaleBlockSize
|
||||
std::tuple< F8, F8, F16, ck::Number<32> >,
|
||||
std::tuple< F8, F8, BF16, ck::Number<32> >
|
||||
#endif
|
||||
>;
|
||||
|
||||
using KernelTypes_BF8_F8_MK_KN = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
// ADataType, BDataType, CDataType, ScaleBlockSize
|
||||
std::tuple< BF8, F8, F16, ck::Number<32> >
|
||||
#endif
|
||||
>;
|
||||
|
||||
using KernelTypes_F8_KM_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
// ADataType, BDataType, CDataType, ScaleBlockSize
|
||||
std::tuple< F8, F8, BF16, ck::Number<32> >
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_F8_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_BF8_F8_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_F8_KM_NK);
|
||||
|
||||
/// A: RowMajor
|
||||
/// B: ColMajor
|
||||
/// C: RowMajor
|
||||
|
||||
TYPED_TEST(TestGemmMX_MK_NK, SmallM)
|
||||
{
|
||||
@@ -95,14 +127,151 @@ TYPED_TEST(TestGemmMX_MK_NK, Regular)
|
||||
|
||||
TYPED_TEST(TestGemmMX_MK_NK, Large)
|
||||
{
|
||||
std::vector<int> Ms{4096};
|
||||
constexpr int N = 3840;
|
||||
constexpr int K = 4096;
|
||||
std::vector<std::pair<int, int>> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}};
|
||||
|
||||
constexpr int K = 4096;
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
|
||||
for(auto test_size : test_sizes)
|
||||
{
|
||||
auto M = test_size.first;
|
||||
auto N = test_size.second;
|
||||
|
||||
const auto StrideC = N;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
/// A: RowMajor
|
||||
/// B: RowMajor
|
||||
/// C: RowMajor
|
||||
|
||||
TYPED_TEST(TestGemmMX_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMX_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMX_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{3840};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMX_MK_KN, Large)
|
||||
{
|
||||
std::vector<std::pair<int, int>> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}};
|
||||
|
||||
constexpr int K = 4096;
|
||||
constexpr int StrideA = K;
|
||||
|
||||
for(auto test_size : test_sizes)
|
||||
{
|
||||
auto M = test_size.first;
|
||||
auto N = test_size.second;
|
||||
|
||||
const auto StrideB = N;
|
||||
const auto StrideC = N;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
/// A: ColMajor
|
||||
/// B: ColMajor
|
||||
/// C: RowMajor
|
||||
|
||||
TYPED_TEST(TestGemmMX_KM_NK, SmallN)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
std::vector<int> Ns{1, 2, 3, 4, 5, 6};
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = M;
|
||||
constexpr int StrideB = K;
|
||||
|
||||
for(int N : Ns)
|
||||
{
|
||||
const auto new_N = N * 8;
|
||||
const auto StrideC = new_N;
|
||||
this->Run(M, new_N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMX_KM_NK, MidLargeN)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
std::vector<int> Ns{127, 255, 312, 799, 1573};
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = M;
|
||||
constexpr int StrideB = K;
|
||||
|
||||
for(int N : Ns)
|
||||
{
|
||||
const auto new_N = (N + 7) / 8 * 8;
|
||||
const auto StrideC = new_N;
|
||||
this->Run(M, new_N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMX_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{3840};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMX_KM_NK, Large)
|
||||
{
|
||||
std::vector<std::pair<int, int>> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}};
|
||||
|
||||
constexpr int K = 4096;
|
||||
constexpr int StrideB = K;
|
||||
|
||||
for(auto test_size : test_sizes)
|
||||
{
|
||||
auto M = test_size.first;
|
||||
auto N = test_size.second;
|
||||
|
||||
const auto StrideA = M;
|
||||
const auto StrideC = N;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ bool profile_gemm_mx_impl(int do_verification,
|
||||
break;
|
||||
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorValue(
|
||||
GeneratorTensor_3<ScaleDataType>{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user