mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
MX GEMM - New GEMM pipeline for MX data types (#2059)
* Allow selection of mfma_scale instructions
* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order
* Add constexpr and synchronize return type for `get_exponent_value`
* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`
* Add support for microscaling instructions in `XdlopsGemm`
* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper
* Remove software implementation of MX GEMM
* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction
* Update README
* Updated CHANGELOG
* Remove unused static methods
[ROCm/composable_kernel commit: 7106976a72]
This commit is contained in:
committed by
GitHub
parent
1a8132e9f9
commit
5e2bd20672
@@ -1,10 +1,5 @@
|
||||
add_custom_target(example_gemm_mx)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale)
|
||||
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale)
|
||||
|
||||
@@ -10,16 +10,16 @@ Custom verification parameters:
|
||||
# arg4: verbosity (0=no info, 1=verbose info)
|
||||
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
|
||||
# arg11: KBatch
|
||||
./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1
|
||||
./bin/example_gemm_mx_fp8 1 1 0 1
|
||||
```
|
||||
|
||||
Custom tensor shapes:
|
||||
```bash
|
||||
./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1
|
||||
./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1
|
||||
```
|
||||
|
||||
Default invocation:
|
||||
```bash
|
||||
# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0
|
||||
./bin/example_gemm_mx_fp8_fp8_scale
|
||||
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
|
||||
./bin/example_gemm_mx_fp8
|
||||
```
|
||||
@@ -95,7 +95,7 @@ bool parse_cmd_args(int argc,
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
|
||||
<< "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg11: KBatch" << std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -103,7 +103,8 @@ bool parse_cmd_args(int argc,
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename DeviceOpInstance,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename XDataType,
|
||||
typename CDataType,
|
||||
@@ -115,65 +116,9 @@ template <typename ADataType,
|
||||
typename CElementOp,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
ck::index_t MXVectorSize>
|
||||
ck::index_t ScaleBlockSize>
|
||||
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
static constexpr ck::index_t ScaleBlockSize = MXVectorSize;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 64;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
MXVectorSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
@@ -230,8 +175,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{}));
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
|
||||
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
|
||||
@@ -428,8 +373,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
std::size_t flop = std::size_t(2) * M * N * K +
|
||||
std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale
|
||||
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of
|
||||
// partial sums(K/ScaleBlockSize)]
|
||||
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
|
||||
@@ -445,7 +392,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
return res_verified;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename DeviceOpInstance,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename XDataType,
|
||||
typename CDataType,
|
||||
@@ -464,7 +412,8 @@ bool run_mx_gemm_example(int argc, char* argv[])
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) &&
|
||||
run_mx_gemm<ADataType,
|
||||
run_mx_gemm<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
|
||||
98
example/67_gemm_microscaling/gemm_mx_fp8.cpp
Normal file
98
example/67_gemm_microscaling/gemm_mx_fp8.cpp
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 32; // scaling block size
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
mx_vector_size>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 32; // scaling block size
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
mx_vector_size>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
using XDataType = ck::f8_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 32; // scaling block size
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
mx_vector_size>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
Reference in New Issue
Block a user