mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
MX GEMM examples with FP8, FP16, and E8M0 scales (#2016)
* Add `scalar_type` specialization for E8M0 exponent
* Specialize `nnvb_data_t_selector` for E8M0 exponent
* Remove partial specializations for `scalar_type` of `non_native_vector_base` template
* Reword command line helper string
* Create MX GEMM examples for different scales
[ROCm/composable_kernel commit: 72d888821c]
This commit is contained in:
committed by
GitHub
parent
21af4139ad
commit
75ef4c83bf
@@ -1,5 +1,10 @@
|
||||
add_custom_target(example_gemm_mx)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
|
||||
add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale)
|
||||
|
||||
@@ -2,16 +2,24 @@
|
||||
|
||||
## example_gemm_mx_fp8
|
||||
|
||||
Custom verification parameters:
|
||||
```bash
|
||||
# arg1: verification (0=no, 1=CPU)
|
||||
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
# arg2: initialization (0=constant values, 1=integer values, 2=decimal values)
|
||||
# arg3: time kernel (0=no, 1=yes)
|
||||
# arg4: verbosity (0=no info, 1=verbose info)
|
||||
# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC
|
||||
./bin/example_gemm_mx_fp8 1 1 0 1
|
||||
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
|
||||
# arg11: KBatch
|
||||
./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1
|
||||
```
|
||||
|
||||
Custom tensor shapes:
|
||||
```bash
|
||||
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
|
||||
./bin/example_gemm_mx_fp8
|
||||
./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1
|
||||
```
|
||||
|
||||
Default invocation:
|
||||
```bash
|
||||
# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0
|
||||
./bin/example_gemm_mx_fp8_fp8_scale
|
||||
```
|
||||
@@ -33,7 +33,7 @@ using ck::type_convert;
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
int do_verification = 1; // (0=no, 1=CPU)
|
||||
int init_method = 2; // (0=no init, 1=integer value, 2=decimal value)
|
||||
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
|
||||
bool time_kernel = false; // (0=no, 1=yes)
|
||||
int verbosity = 0; // (0=no info, 1=verbose info)
|
||||
};
|
||||
@@ -91,11 +91,11 @@ bool parse_cmd_args(int argc,
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
|
||||
<< "arg5 to 10: M(256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg11: KBatch" << std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -223,7 +223,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
};
|
||||
|
||||
// Hardcode scale layouts as per pipeline assumptions
|
||||
// TODO: Change default scale layouts to Col for A and Row for B
|
||||
// TODO: Allow user to specify scale layouts
|
||||
using AScaleLayout = Row;
|
||||
using BScaleLayout = Col;
|
||||
@@ -271,19 +270,31 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
break;
|
||||
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.0f, 4.0f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
|
||||
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-4.0f, 5.0f}(b_k_n);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
|
||||
|
||||
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
|
||||
{
|
||||
a_m_k_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
|
||||
}
|
||||
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
|
||||
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
|
||||
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
break;
|
||||
|
||||
default:
|
||||
|
||||
42
example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp
Normal file
42
example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
// Compile-time configuration of this example. Every alias below, plus
// mx_vector_size, is passed as a template argument to run_mx_gemm_example
// from main().

// Element types of the A and B input matrices: both FP8.
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
// Scale element type for this variant of the example: the E8M0 exponent type.
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
// C matrix element type is half precision.
using CDataType = ck::half_t;
|
||||
// NOTE(review): presumably the accumulation type of the GEMM — confirm in
// gemm_mx_common.hpp, where these aliases are consumed.
using AccDataType = float;
|
||||
// Same type as the C matrix elements.
using CShuffleDataType = CDataType;
|
||||
|
||||
// Row / Col are not declared in this file; presumably they come from
// gemm_mx_common.hpp — TODO confirm.
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
// PassThrough is likewise declared outside this file.
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 32; // scaling block size
|
||||
|
||||
// Entry point of the MX GEMM example with FP8 A/B data and E8M0 scales.
//
// Forwards argc/argv unchanged to run_mx_gemm_example, instantiated with the
// type, layout and element-wise-operator aliases and the mx_vector_size
// constant declared earlier in this file.
// NOTE(review): run_mx_gemm_example is not defined here; presumably it comes
// from gemm_mx_common.hpp and handles argument parsing and the GEMM run —
// confirm there.
//
// Exit status: 0 when run_mx_gemm_example yields a true result, -1 otherwise.
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
mx_vector_size>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -6,9 +6,7 @@
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
// TODO: Enable e8m0_bexp_t and FP8 scale types
|
||||
using XDataType = ck::half_t;
|
||||
// using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
42
example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp
Normal file
42
example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
// Compile-time configuration of this example. Every alias below, plus
// mx_vector_size, is passed as a template argument to run_mx_gemm_example
// from main().

// Element types of the A and B input matrices: both FP8.
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
// Scale element type for this variant of the example: FP8, the same type as
// the A and B elements.
using XDataType = ck::f8_t;
|
||||
|
||||
// C matrix element type is half precision.
using CDataType = ck::half_t;
|
||||
// NOTE(review): presumably the accumulation type of the GEMM — confirm in
// gemm_mx_common.hpp, where these aliases are consumed.
using AccDataType = float;
|
||||
// Same type as the C matrix elements.
using CShuffleDataType = CDataType;
|
||||
|
||||
// Row / Col are not declared in this file; presumably they come from
// gemm_mx_common.hpp — TODO confirm.
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
// PassThrough is likewise declared outside this file.
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 32; // scaling block size
|
||||
|
||||
// Entry point of the MX GEMM example with FP8 A/B data and FP8 scales.
//
// Forwards argc/argv unchanged to run_mx_gemm_example, instantiated with the
// type, layout and element-wise-operator aliases and the mx_vector_size
// constant declared earlier in this file.
// NOTE(review): run_mx_gemm_example is not defined here; presumably it comes
// from gemm_mx_common.hpp and handles argument parsing and the GEMM run —
// confirm there.
//
// Exit status: 0 when run_mx_gemm_example yields a true result, -1 otherwise.
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
mx_vector_size>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -463,6 +463,13 @@ struct scalar_type<bf8_ocp_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// scalar_type specialization for the E8M0 exponent type.
// `type` is the nested alias e8m0_bexp_t::type, and the type is reported as a
// single element (vector_size == 1), like the neighbouring scalar
// specializations.
template <>
|
||||
struct scalar_type<e8m0_bexp_t>
|
||||
{
|
||||
using type = e8m0_bexp_t::type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bool>
|
||||
{
|
||||
|
||||
@@ -1212,6 +1212,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
|
||||
using type = bf8_ocp_t::data_type;
|
||||
};
|
||||
|
||||
// nnvb_data_t_selector specialization for the E8M0 exponent type: selects the
// nested alias e8m0_bexp_t::type.
// NOTE(review): presumably this is the element storage type that
// non_native_vector_base<e8m0_bexp_t, N> uses — confirm against the primary
// template, which is not visible here.
template <>
|
||||
struct nnvb_data_t_selector<e8m0_bexp_t>
|
||||
{
|
||||
using type = e8m0_bexp_t::type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f6x16_pk_t>
|
||||
{
|
||||
@@ -1400,29 +1406,9 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<non_native_vector_base<T, N>>;
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
|
||||
struct scalar_type<non_native_vector_base<T, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
|
||||
|
||||
using type = typename non_native_vector_base<T, N>::data_t;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user