diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 35c5d18d50..14b648c9f8 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -13,6 +13,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp6) +add_example_executable(example_gemm_mx_bf6 gemm_mx_bf6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf6) + add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) @@ -62,3 +65,4 @@ example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) set(FP6_MXGEMM_OPTIONS) list(APPEND FP6_MXGEMM_OPTIONS -mavx512f) example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index 57b6490eda..007c934b7e 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -8,14 +8,16 @@ Custom verification parameters: # arg2: initialization (0=constant values, 1=integer values, 2=decimal values) # arg3: time kernel (0=no, 1=yes) # arg4: verbosity (0=no info, 1=verbose info) -# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC +# arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC # arg11: KBatch +# arg12: warmup runs pre-timing +# arg13: repeat run count for timing ./bin/example_gemm_mx_fp8 1 1 0 1 ``` Custom tensor shapes: ```bash -./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1 +./bin/example_gemm_mx_fp8 1 2 1 0 256 256 512 -1 -1 -1 1 10 10 ``` Default invocation: diff --git a/example/67_gemm_microscaling/gemm_mx_bf6.cpp b/example/67_gemm_microscaling/gemm_mx_bf6.cpp new file mode 100644 index 
0000000000..34810c2961 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf6.cpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf6x16_pk_t; +using BDataType = ck::bf6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 16; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 1, // AK1 + 1, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // 
MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 
0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 6ce10817ff..2d0585c880 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -100,8 +100,11 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl - << "arg11: KBatch" << std::endl; + << "arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC" << std::endl + << "arg11: KBatch" << std::endl + << "arg12: warmup runs pre-timing" << std::endl + << "arg13: repeat run count for timing" << std::endl; + return false; } diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 46028b79f9..33c918c997 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -550,7 +550,14 @@ struct Tensor auto dis_ = dis; // copy g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v); auto t_fn = [&]() { - if constexpr(ck::packed_size_v == 1) + // As the user can pass an integer distribution in dis, we must ensure that the correct + // constructor/converter is called at all times. For f4/f6/f8 types, to ensure + // correct results, we convert from float to the target type. In these cases + // integer constructors are interpreted as direct initialization of the internal + // storage with binary values instead of treating integers as a subset of floats. 
+ if constexpr(ck::is_same_v || ck::is_same_v) + return ck::type_convert(static_cast(fn(dis_(g_)))); + else if constexpr(ck::packed_size_v == 1) return ck::type_convert(fn(dis_(g_))); else if constexpr(ck::is_same_v) return ck::f4x2_pk_t{ck::type_convert( diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index efb877b3f2..8646b8393b 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1118,6 +1118,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> #endif } + template + __device__ static void Run(const bf6x16x2_t& reg_a, + const int32_t scale_a, + const bf6x16x2_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + using arg_type = int32x8_t; + arg_type arg_a{ + static_cast(reg_a.template AsType()[Number<0>{}][0]), + static_cast(reg_a.template AsType()[Number<0>{}][1]), + static_cast(reg_a.template AsType()[Number<0>{}][2]), + static_cast(reg_a.template AsType()[Number<1>{}][0]), + static_cast(reg_a.template AsType()[Number<1>{}][1]), + static_cast(reg_a.template AsType()[Number<1>{}][2]), + 0, + 0}; + arg_type arg_b{ + static_cast(reg_b.template AsType()[Number<0>{}][0]), + static_cast(reg_b.template AsType()[Number<0>{}][1]), + static_cast(reg_b.template AsType()[Number<0>{}][2]), + static_cast(reg_b.template AsType()[Number<1>{}][0]), + static_cast(reg_b.template AsType()[Number<1>{}][1]), + static_cast(reg_b.template AsType()[Number<1>{}][2]), + 0, + 0}; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_a, + arg_b, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t 
scale_a, diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 15b8841c39..8f5a45bdf0 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -60,6 +60,17 @@ struct f4x2_pk_t { return (x0 << 4) | (x1 & 0b00001111); } + + // Compare operator + __host__ __device__ friend bool operator==(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs) + { + return lhs.data == rhs.data; + } + + __host__ __device__ friend bool operator!=(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs) + { + return !(lhs == rhs); + } }; template diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index effe445883..ae0edb35ee 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -2254,8 +2254,9 @@ using f6x16x2_t = typename vector_type::type; using f6x32_t = typename vector_type::type; // bf6 -using bf6x16_t = typename vector_type::type; -using bf6x32_t = typename vector_type::type; +using bf6x16_t = typename vector_type::type; +using bf6x16x2_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; // e8m0 using e8m0x4_bexp_t = typename vector_type::type; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 23ab1bebb5..05e461fa63 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -2102,17 +2102,15 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 float float_array[32]; } in{x}; - union - { - bf6x32_t bf6_vector; - bf6_t bf6_array[32]; - } out{}; + using array_type = uint8_t __attribute__((ext_vector_type(32))); + array_type uint8_array; + // collect the 6-bit values into an array ck::static_for<0, 32, 1>{}([&](auto i) { - out.bf6_array[i] = utils::sat_convert_to_type(in.float_array[i] / scale); + uint8_array[static_cast(i)] = + utils::sat_convert_to_type(in.float_array[i] / scale); }); - - return out.bf6_vector; + return 
bf6x32_t{bf6x32_pk_t{uint8_array}}; #endif } @@ -2257,6 +2255,37 @@ inline __host__ __device__ bf6x32_pk_t type_convert(floa return static_cast(type_convert(x)); } +template <> +inline __host__ __device__ bf6x16_t type_convert(float16_t x) +{ + + union + { + float16_t v16x2[2]; + float32_t v32; + } in{{x, x}}; + + union + { + bf6x32_t v32; + bf6x16_t v16x2[2]; + } out{}; + +#if CK_USE_SR_F6_CONVERSION + out.v32 = bf6_convert_sr(in.v32); +#else + out.v32 = bf6_convert_rne(in.v32); +#endif + + return out.v16x2[0]; +} + +template <> +inline __host__ __device__ bf6x16_pk_t type_convert(float16_t x) +{ + return static_cast(type_convert(x)); +} + /** * @brief Specializes the type conversion template for converting a bf6_t value to float. * @@ -2329,6 +2358,32 @@ inline __host__ __device__ float32_t type_convert(bf6x32_t return out.float_vector; #endif } + +template <> +inline __host__ __device__ float16_t type_convert(bf6x16_t x) +{ + union + { + bf6x16_t v16x2[2]; + bf6x32_t v32; + } in{{x, x}}; + + union + { + float16_t v16x2[2]; + float32_t v32; + } out{}; + + out.v32 = type_convert(in.v32); + return out.v16x2[0]; +} + +template <> +inline __host__ __device__ float16_t type_convert(bf6x16_pk_t x) +{ + return type_convert(static_cast(x)); +} + #endif #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 7e23998f8c..32d5464e8f 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -60,6 +60,7 @@ if(GPU_TARGETS MATCHES "gfx950") add_gtest_executable(test_bf6 test_bf6.cpp) if(result EQUAL 0) + target_compile_options(test_bf6 PRIVATE -mavx512f) target_link_libraries(test_bf6 PRIVATE utility) endif() add_dependencies(test_mx_data_types test_bf6) diff --git a/test/data_type/test_bf6.cpp b/test/data_type/test_bf6.cpp index 25c01076e9..904cd302dc 100644 --- a/test/data_type/test_bf6.cpp +++ b/test/data_type/test_bf6.cpp @@ -6,6 +6,7 @@ #include 
"ck/utility/type_convert.hpp" #include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" +#include "ck/library/utility/device_memory.hpp" using ck::bf6_convert_rne; using ck::bf6_convert_sr; @@ -455,3 +456,57 @@ TEST(BF6, TestAllValues) } }); } + +__global__ void test_bf6_convert_rne(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + ck::float32_t float32_in(1.0f); + ck::float32_t float32_out{}; + + auto bf6x32_vec = bf6_convert_rne(float32_in); + float32_out = type_convert(bf6x32_vec); + + ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32_out[static_cast(ii)]; }); + i = N; +} + +TEST(MXBF6, DeviceBF6ConvertRNE) +{ + constexpr int N = 32; + std::vector out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_bf6_convert_rne<<<1, 1>>>(static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + EXPECT_EQ(N, completed); + ck::static_for<0, N, 1>{}( + [&](auto ii) { EXPECT_EQ(out[static_cast(ii)], 1.0f) << "ii: " << ii << std::endl; }); + + auto bf6x32_vec_tc = ck::type_convert(ck::float32_t(1.0f)); + auto bf6x32_vec_cnstr = bf6x32_pk_t(0x0C); + + EXPECT_EQ(bf6x32_vec_tc, bf6x32_vec_cnstr); +}