diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 93770684df..9e95c3e007 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,5 +1,10 @@ add_custom_target(example_gemm_mx) -add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) +add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale) +add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale) + +add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index c0a0972db6..713902588d 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -2,16 +2,24 @@ ## example_gemm_mx_fp8 +Custom verification parameters: ```bash # arg1: verification (0=no, 1=CPU) -# arg2: initialization (0=no init, 1=integer value, 2=decimal value) +# arg2: initialization (0=constant values, 1=integer values, 2=decimal values) # arg3: time kernel (0=no, 1=yes) # arg4: verbosity (0=no info, 1=verbose info) -# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC -./bin/example_gemm_mx_fp8 1 1 0 1 +# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC +# arg11: KBatch +./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1 ``` +Custom tensor shapes: ```bash -# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 -./bin/example_gemm_mx_fp8 +./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1 +``` + +Default invocation: +```bash +# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0 +./bin/example_gemm_mx_fp8_fp8_scale ``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index b8ff765174..9a05954c73 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -33,7 +33,7 @@ using ck::type_convert; struct ExecutionConfig final { int do_verification = 1; // (0=no, 1=CPU) - int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) + int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) bool time_kernel = false; // (0=no, 1=yes) int verbosity = 0; // (0=no info, 1=verbose info) }; @@ -91,11 +91,11 @@ bool parse_cmd_args(int argc, else { std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl << "arg11: KBatch" << std::endl; return false; } @@ -223,7 +223,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c }; // Hardcode scale layouts as per pipeline assumptions - // TODO: Change default scale layouts to Col for A and Row for B // TODO: Allow user to specify scale layouts using AScaleLayout = Row; using BScaleLayout = Col; @@ -271,19 +270,31 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.0f, 4.0f}(a_m_k); - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); - ck::utils::FillUniformDistributionIntegerValue{-4.0f, 5.0f}(b_k_n); - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + + if constexpr(ck::is_same_v) + { + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + } + else + { + ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); + ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); + } + break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}); + a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}); + b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); break; default: diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp new file mode 100644 index 0000000000..393f4a2ea7 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 32; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp similarity index 93% rename from example/67_gemm_microscaling/gemm_mx_fp8.cpp rename to example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp index a8ec0e14af..dd654a8f69 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp @@ -6,9 +6,7 @@ using ADataType = ck::f8_t; using BDataType = ck::f8_t; -// TODO: Enable e8m0_bexp_t and FP8 scale types using XDataType = ck::half_t; -// using XDataType = ck::e8m0_bexp_t; using CDataType = ck::half_t; using AccDataType = float; diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp new file mode 100644 index 0000000000..c42d9783be --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::f8_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 32; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 6b7aaf2162..9732739994 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -463,6 +463,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = e8m0_bexp_t::type; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 302ebd86b7..8f70962fa6 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1212,6 +1212,12 @@ struct nnvb_data_t_selector using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = e8m0_bexp_t::type; +}; + template <> struct nnvb_data_t_selector { @@ -1400,29 +1406,9 @@ struct non_native_vector_base -struct scalar_type>; - -template -struct scalar_type> +struct scalar_type> { - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - + using type = typename non_native_vector_base::data_t; static constexpr index_t vector_size = N; };