MX GEMM examples with FP8, FP16, and E8M0 scales (#2016)

* Add `scalar_type` specialization for E8M0 exponent

* Specialize `nnvb_data_t_selector` for E8M0 exponent

* Remove partial specializations for `scalar_type` of `non_native_vector_base` template

* Reword command line helper string

* Create MX GEMM examples for different scales


[ROCm/composable_kernel commit: 72d888821c]
This commit is contained in:
Andriy Roshchenko
2025-03-25 15:33:03 -06:00
committed by GitHub
parent aaf6a0343d
commit 3f06d019ba
8 changed files with 140 additions and 41 deletions

View File

@@ -1,5 +1,10 @@
# Umbrella target that groups all MX GEMM examples below.
add_custom_target(example_gemm_mx)
# FP8 A/B operands with the example's default scale type.
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
# FP8 A/B operands with E8M0 (8-bit biased exponent) block scales.
add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale)
# FP8 A/B operands with FP8 block scales.
add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale)
# FP8 A/B operands with FP16 block scales.
add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale)

View File

@@ -2,16 +2,24 @@
## example_gemm_mx_fp8
Custom verification parameters:
```bash
# arg1: verification (0=no, 1=CPU)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg2: initialization (0=constant values, 1=integer values, 2=decimal values)
# arg3: time kernel (0=no, 1=yes)
# arg4: verbosity (0=no info, 1=verbose info)
# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC
./bin/example_gemm_mx_fp8 1 1 0 1
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
# arg11: KBatch
./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1
```
Custom tensor shapes:
```bash
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
./bin/example_gemm_mx_fp8
./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1
```
Default invocation:
```bash
# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0
./bin/example_gemm_mx_fp8_fp8_scale
```

View File

@@ -33,7 +33,7 @@ using ck::type_convert;
struct ExecutionConfig final
{
int do_verification = 1; // (0=no, 1=CPU)
int init_method = 2; // (0=no init, 1=integer value, 2=decimal value)
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
};
@@ -91,11 +91,11 @@ bool parse_cmd_args(int argc,
else
{
std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
<< "arg5 to 10: M(256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl;
return false;
}
@@ -223,7 +223,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
};
// Hardcode scale layouts as per pipeline assumptions
// TODO: Change default scale layouts to Col for A and Row for B
// TODO: Allow user to specify scale layouts
using AScaleLayout = Row;
using BScaleLayout = Col;
@@ -271,19 +270,31 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.0f, 4.0f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-4.0f, 5.0f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
{
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
}
else
{
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
}
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
break;
default:

View File

@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
// Compile-time problem configuration for this example:
// FP8 operands for A and B, E8M0 block scales, FP16 output.
using ADataType = ck::f8_t; // element type of matrix A (FP8)
using BDataType = ck::f8_t; // element type of matrix B (FP8)
using XDataType = ck::e8m0_bexp_t; // block-scale type: 8-bit biased exponent (E8M0)
using CDataType = ck::half_t; // element type of output matrix C (FP16)
using AccDataType = float; // accumulation performed in float — per name; TODO confirm in driver
using CShuffleDataType = CDataType; // type used in the C-shuffle stage — matches output type
using ALayout = Row; // A is row-major
using BLayout = Col; // B is column-major
using CLayout = Row; // C is row-major
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t mx_vector_size = 32; // scaling block size
/// Example entry point: forwards argc/argv to the MX GEMM example driver
/// configured by the file-scope type aliases above.
/// Returns 0 when the driver reports success, -1 otherwise.
int main(int argc, char* argv[])
{
    const bool ok = run_mx_gemm_example<ADataType,
                                        BDataType,
                                        XDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout,
                                        AElementOp,
                                        BElementOp,
                                        CElementOp,
                                        AccDataType,
                                        CShuffleDataType,
                                        mx_vector_size>(argc, argv);
    return ok ? 0 : -1;
}

View File

@@ -6,9 +6,7 @@
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
// TODO: Enable e8m0_bexp_t and FP8 scale types
using XDataType = ck::half_t;
// using XDataType = ck::e8m0_bexp_t;
using CDataType = ck::half_t;
using AccDataType = float;

View File

@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
// Compile-time problem configuration for this example:
// FP8 operands for A and B, FP8 block scales, FP16 output.
using ADataType = ck::f8_t; // element type of matrix A (FP8)
using BDataType = ck::f8_t; // element type of matrix B (FP8)
using XDataType = ck::f8_t; // block-scale type: FP8
using CDataType = ck::half_t; // element type of output matrix C (FP16)
using AccDataType = float; // accumulation performed in float — per name; TODO confirm in driver
using CShuffleDataType = CDataType; // type used in the C-shuffle stage — matches output type
using ALayout = Row; // A is row-major
using BLayout = Col; // B is column-major
using CLayout = Row; // C is row-major
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t mx_vector_size = 32; // scaling block size
/// Example entry point: forwards argc/argv to the MX GEMM example driver
/// configured by the file-scope type aliases above.
/// Returns 0 when the driver reports success, -1 otherwise.
int main(int argc, char* argv[])
{
    const bool ok = run_mx_gemm_example<ADataType,
                                        BDataType,
                                        XDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout,
                                        AElementOp,
                                        BElementOp,
                                        CElementOp,
                                        AccDataType,
                                        CShuffleDataType,
                                        mx_vector_size>(argc, argv);
    return ok ? 0 : -1;
}

View File

@@ -463,6 +463,13 @@ struct scalar_type<bf8_ocp_t>
static constexpr index_t vector_size = 1;
};
template <>
// Trait specialization: e8m0_bexp_t (8-bit biased exponent scale) is a scalar,
// not a packed vector — its underlying storage type is e8m0_bexp_t::type and it
// holds exactly one element.
struct scalar_type<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bool>
{

View File

@@ -1212,6 +1212,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
template <>
// Selector specialization so non_native_vector_base can carry e8m0_bexp_t
// elements: maps the wrapper type to its underlying storage type.
struct nnvb_data_t_selector<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
};
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
@@ -1400,29 +1406,9 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
struct scalar_type<non_native_vector_base<T, N>>
{
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
using type = typename non_native_vector_base<T, N>::data_t;
static constexpr index_t vector_size = N;
};