diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index db162fe444..63a2aea0b3 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count); - d_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size std::size_t flop = 0, num_btype = 0; diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 71417ce7bf..257636d956 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -360,10 +360,9 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) + if constexpr(ck::is_packed_type_v>) { - return (mDesc.GetElementSpaceSize() + 1) / 2; + return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v>; } else { @@ -516,69 +515,31 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mDesc.GetOffsetFromMultiIndex(is...) / 2; - } - else - { - return mDesc.GetOffsetFromMultiIndex(is...); - } + return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v>; } template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(is...) / + ck::packed_size_v>]; } template const T& operator()(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(is...) / + ck::packed_size_v>]; } T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } typename Data::iterator begin() { return mData.begin(); } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 785f74a3c0..f48ba49bbf 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -67,6 +67,18 @@ struct GeneratorTensor_1 return ck::type_convert(value); } }; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf8_t operator()(Is...) 
+ { + return ck::type_convert(value); + } +}; #endif template <> @@ -93,6 +105,38 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + template <> struct GeneratorTensor_1 { @@ -132,6 +176,44 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + template <> struct GeneratorTensor_2 { @@ -342,6 +424,46 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + template struct GeneratorTensor_4 { @@ -360,6 +482,69 @@ struct GeneratorTensor_4 } }; +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f4x2_pk_t operator()(Is...) + { + float fp32_tmp0 = distribution(generator); + float fp32_tmp1 = distribution(generator); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f6x32_pk_t operator()(Is...) 
+ { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + struct GeneratorTensor_Checkboard { template @@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential } }; +template +struct GeneratorTensor_Sequential +{ + template + ck::f4x2_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + return ck::type_convert(ck::float2_t(tmp)); + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::f6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::bf6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + template struct GeneratorTensor_Diagonal { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 66c4958e1d..ad48389625 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, 0, @@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + 
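+    // Note on the cbsz/blgp arguments: for the f8f6f4 builtins, cbsz encodes the
+    // format of matrix A and blgp the format of matrix B, per the table above
+    // (0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1). FP6 operands
+    // carry 32 six-bit values in six dwords, while the builtin expects eight,
+    // hence the zero-padded int32x8_t arguments constructed above.
+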
template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; @@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. 
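+    // A minimal usage sketch of the safe pattern described above; treating
+    // e8m0_bexp_t as a one-byte biased-exponent type whose raw bits can be
+    // widened this way is an assumption here, not part of this change:
+    //   e8m0_bexp_t a_scale{1.0f};
+    //   int32_t scale_arg = static_cast<int32_t>(ck::bit_cast<uint8_t>(a_scale));
+    //   Run(reg_a, scale_arg, reg_b, scale_arg, reg_c);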
+#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const bf8x32_t& reg_a, const int32_t& scale_a, @@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t scale_a, @@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, reg_c.template AsType()[Number<0>{}], - 4, // cbsz + 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 4, // blgp 0, // OPSEL scale_a, @@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + 
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t scale_a, @@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> ignore = reg_b; ignore = scale_b; ignore = reg_c; -#endif - } - - template - __device__ static void Run(const bf8x32_t& reg_a, - const int32_t& scale_a, - const f8x32_t& reg_b, - const int32_t& scale_b, - FloatC& reg_c) - { -#if defined(__gfx950__) - // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - reg_a, - reg_b, - reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL - scale_a, - 0, // OPSEL - scale_b); -#else - ignore = reg_a; - ignore = scale_a; - ignore = reg_b; - ignore = scale_b; - ignore = reg_c; #endif } }; @@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, 0, @@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = 
bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a6106bb146..c11b9c0272 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format +// scalar_type +template +struct scalar_type; + struct f4x2_pk_t { + static constexpr int packed_size = 2; + using type = uint8_t; type data; __host__ __device__ f4x2_pk_t() : data{type{}} {} @@ -55,269 +61,82 @@ struct f4x2_pk_t } }; -struct f6x16_pk_t +template +struct f6_pk_t { - // store 16 elements of f6_t in an array of 3 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); - f6x16_pk_t() : data{type{}} {} - f6x16_pk_t(type init) : data{init} {} + using element_type = uint32_t; // element storage fundamental type - template - __host__ __device__ inline f6_t unpack(Number) + static constexpr index_t packed_size = pk_size; + static constexpr index_t num_bits_elem = 6; + static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT; + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + + using storage_type = StaticallyIndexedArray_v2; + storage_type data; // packed data + + using type = f6_pk_t; + + __host__ __device__ constexpr f6_pk_t() : data{} {} + __host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {} + template ::vector_size == packed_size>> + __host__ __device__ f6_pk_t(const T& v) : data{} { - static_assert(I < 16, "Index out of range for 16 f6_t elements."); + static_for<0, packed_size, 1>{}( + [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); + } - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) 
>> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + template + __host__ __device__ void pack(const T x, const index_t i) + { + static_assert(is_integral::value || is_same_v, + "T must be an integral type."); - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + uint32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = data.data_[arr_index]; + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + data.data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + uint32_t next_value = data.data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data.data_[arr_index + 1] = next_value; + } + } + + __host__ __device__ static inline BitType unpack(const type& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + uint32_t bits = pk.data.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); } - return static_cast(bits & 0x3F); + return static_cast(bits & 0x3F); } - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 16 f6_t values, place its 6 bits in the correct position - ck::static_for<0, 16, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } + __host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); } }; -struct f6x32_pk_t -{ - // store 32 elements of f6_t in an array of 6 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); - f6x32_pk_t() : data{type{}} {} - f6x32_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline f6_t unpack(Number) - { - static_assert(I < 32, "Index out of range for 32 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = 
data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 32 f6_t values, place its 6 bits in the correct position - ck::static_for<0, 32, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; - -struct bf6x16_pk_t -{ - // store 16 elements of bf6_t in an array of 3 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); - bf6x16_pk_t() : data{type{}} {} - bf6x16_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline bf6_t unpack(Number) - { - static_assert(I < 16, "Index out of range for 16 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 16 bf6_t values, place its 6 bits in the correct position - ck::static_for<0, 16, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; - -struct bf6x32_pk_t -{ - // store 32 elements of bf6_t in an 
array of 6 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); - bf6x32_pk_t() : data{type{}} {} - bf6x32_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline bf6_t unpack(Number) - { - static_assert(I < 32, "Index out of range for 32 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 32 bf6_t values, place its 6 bits in the correct position - ck::static_for<0, 32, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; +using f6x16_pk_t = f6_pk_t; +using f6x32_pk_t = f6_pk_t; +using bf6x16_pk_t = f6_pk_t; +using bf6x32_pk_t = f6_pk_t; // custom data type - pack int4 data struct pk_i4_t @@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x) } // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool, f4_t, f6_t, bf6_t +// native types: bool template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value; + is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value; } // scalar_type @@ -484,6 +302,106 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +// Default behavior for types that do not need special handling +template +struct packed_type +{ + using type = T; + static constexpr index_t packed_size = 1; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = pk_i4_t; + static constexpr index_t packed_size = 2; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = f4x2_pk_t; + static constexpr index_t packed_size = 2; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = f6x32_pk_t; + static 
constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = bf6x32_pk_t; + static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements +}; + +template +using packed_type_t = typename packed_type::type; + +// Check if the type has packed type specialization +template +inline constexpr bool has_packed_type_v = !is_same_v, T>; + +template +struct element_type +{ + private: + static constexpr auto get_element_type() + { + using U = remove_cvref_t; + if constexpr(is_same_v) + return int4_t{}; + else if constexpr(is_same_v) + return f4_t{}; + else if constexpr(is_same_v) + return f6_t{}; + else if constexpr(is_same_v) + return bf6_t{}; + else if constexpr(is_same_v) + return f6_t{}; + else if constexpr(is_same_v) + return bf6_t{}; + else + return T{}; + } + + public: + using type = decltype(get_element_type()); +}; +template +using element_type_t = typename element_type::type; + +template +inline constexpr bool is_packed_type_v = + has_packed_type_v>&& is_same_v>>; + +template +struct packed_size +{ + private: + static constexpr auto get_packed_size() + { + using U = remove_cvref_t; + if constexpr(is_packed_type_v) + return Number>::packed_size>{}; + else + return Number::packed_size>{}; + } + + public: + using type = decltype(get_packed_size()); + static constexpr auto value = get_packed_size(); +}; + +template +using packed_size_t = typename packed_size::type; + +template +inline constexpr index_t packed_size_v = packed_size::value; + #if defined(_WIN32) using int64_t = long long; #else diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 9c40d923d3..65eed0624c 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -365,6 +365,88 @@ struct vector_type()>> } }; +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + typedef T d6_t __attribute__((ext_vector_type(6))); + + using type = d6_t; + + union + { + d6_t d6_; + StaticallyIndexedArray d1x6_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d3x2_; + StaticallyIndexedArray d6x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } +}; + template struct vector_type()>> { @@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector template <> 
struct nnvb_data_t_selector { - using type = f6x16_pk_t::type; + using type = f6x16_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = f6x32_pk_t::type; + using type = f6x32_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = bf6x16_pk_t::type; + using type = bf6x16_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = bf6x32_pk_t::type; + using type = bf6x32_pk_t::storage_type; }; template <> @@ -1406,12 +1488,23 @@ struct non_native_vector_base -struct scalar_type> +struct scalar_type>> { using type = typename non_native_vector_base::data_t; static constexpr index_t vector_size = N; }; +template +struct scalar_type< + non_native_vector_base>> +{ + using type = typename non_native_vector_base::element_t; + static constexpr index_t vector_size = N * non_native_vector_base::size_factor; +}; + // non-native vector_type implementation template struct vector_type()>> @@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type::type; // i32 using int32x2_t = typename vector_type::type; using int32x4_t = typename vector_type::type; +using int32x6_t = typename vector_type::type; using int32x8_t = typename vector_type::type; using int32x16_t = typename vector_type::type; using int32x32_t = typename vector_type::type; diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index b0b5297f77..53edb6e182 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type(float value) if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) - return value < 0 ? NumericUtils::negative_zero_mask - : NumericUtils::positive_zero_mask; + return sign ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; return res; } @@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) - return value < 0 ? NumericUtils::negative_zero_mask - : NumericUtils::positive_zero_mask; + return sign ? 
NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; return res; } diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index cf68188b3e..a840c520a9 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr(float value, uint32 return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr(float value, uint if(std::isnan(value)) return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? 
NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index c8d284a1d7..ed07e53e6d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -89,6 +89,14 @@ struct ReferenceGemm : public device::BaseOperator v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + v_a = type_convert( + arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -115,6 +123,14 @@ struct ReferenceGemm : public device::BaseOperator v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + v_b = type_convert( + arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index e8fdcf1acd..3fc39911dd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -105,6 +105,16 @@ struct ReferenceMXGemm : public device::BaseOperator type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } else { a_m_k_scaled(m, k) = @@ -134,6 +144,16 @@ struct ReferenceMXGemm : public device::BaseOperator type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } else { b_k_n_scaled(k, n) = diff --git a/test/data_type/test_bf6.cpp b/test/data_type/test_bf6.cpp index a260f81d16..9dbb77454c 100644 --- a/test/data_type/test_bf6.cpp +++ b/test/data_type/test_bf6.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" using ck::bf6_convert_rne; @@ -41,6 +42,11 @@ TEST(BF6, ConvertFP32Nearest) ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(std::numeric_limits::infinity())), 0.0f); + + // convert float +/-30 to bf6 and back, check if clipped to +/-max_bf6 + ASSERT_NEAR(-max_bf6, type_convert(bf6_convert_rne(-30.0f)), 0.0f); + ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(30.0f)), 0.0f); + // convert float value less than bf6 subnorm to bf6 and back, check if equal to 0.0 float less_than_subnorm = 0.03125f; ASSERT_NEAR(0.0f, type_convert(bf6_convert_rne(less_than_subnorm)), 0.0f); @@ -266,21 +272,18 @@ TEST(BF6, TestAsType16x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - 0); + 
ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = bf6x16_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); }); } @@ -329,23 +332,23 @@ TEST(BF6, TestAsType16x2) // check default CTOR ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(right_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).unpack(idx_element), + 0); }); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec[i]); + right_vec.template AsType()(Number{}) = bf6x16_pk_t{test_vec[i]}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(left_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - static_cast(test_vec[idx_vector][static_cast(idx_element)])); + ASSERT_EQ( + left_vec.template AsType()(Number{}).unpack(idx_element), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); }); }); } @@ -369,20 +372,86 @@ TEST(BF6, TestAsType32x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x32_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = bf6x32_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); + }); +} + +TEST(BF6, TestAllValues) +{ + + constexpr std::array e3m2ValuesOCP = { + // clang-format off + 0.0000000000, 0.0625000000, 0.1250000000, 0.1875000000, + 0.2500000000, 0.3125000000, 0.3750000000, 0.4375000000, + 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000, + 1.0000000000, 1.2500000000, 1.5000000000, 1.7500000000, + 2.0000000000, 2.5000000000, 3.0000000000, 3.5000000000, + 4.0000000000, 5.0000000000, 6.0000000000, 7.0000000000, + 8.0000000000, 10.0000000000, 12.0000000000, 14.0000000000, + 16.0000000000, 20.0000000000, 24.0000000000, 28.0000000000, + -0.0000000000, -0.0625000000, -0.1250000000, -0.1875000000, + -0.2500000000, -0.3125000000, -0.3750000000, -0.4375000000, + -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000, + -1.0000000000, -1.2500000000, -1.5000000000, -1.7500000000, + -2.0000000000, -2.5000000000, 
-3.0000000000, -3.5000000000, + -4.0000000000, -5.0000000000, -6.0000000000, -7.0000000000, + -8.0000000000, -10.0000000000, -12.0000000000, -14.0000000000, + -16.0000000000, -20.0000000000, -24.0000000000, -28.0000000000 + // clang-format on + }; + + constexpr uint8_t e3m2BitsOCP[] = { + // clang-format off + 0b000000, 0b000001, 0b000010, 0b000011, + 0b000100, 0b000101, 0b000110, 0b000111, + 0b001000, 0b001001, 0b001010, 0b001011, + 0b001100, 0b001101, 0b001110, 0b001111, + 0b010000, 0b010001, 0b010010, 0b010011, + 0b010100, 0b010101, 0b010110, 0b010111, + 0b011000, 0b011001, 0b011010, 0b011011, + 0b011100, 0b011101, 0b011110, 0b011111, + 0b100000, 0b100001, 0b100010, 0b100011, + 0b100100, 0b100101, 0b100110, 0b100111, + 0b101000, 0b101001, 0b101010, 0b101011, + 0b101100, 0b101101, 0b101110, 0b101111, + 0b110000, 0b110001, 0b110010, 0b110011, + 0b110100, 0b110101, 0b110110, 0b110111, + 0b111000, 0b111001, 0b111010, 0b111011, + 0b111100, 0b111101, 0b111110, 0b111111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("BF6 Table\n"); + ck::static_for<0, 64, 1>{}([&](auto i) { + float fp = type_convert(bf6_t(e3m2BitsOCP[i])); + ASSERT_EQ(fp, e3m2ValuesOCP[i]); + + bf6_t bf6 = type_convert(e3m2ValuesOCP[i]); + ASSERT_EQ(bf6 & 0x3F, e3m2BitsOCP[i] & 0x3F); + + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 5; j >= 0; --j) + { + printf("%c", (e3m2BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e3m2BitsOCP[i], e3m2ValuesOCP[i]); + } }); } diff --git a/test/data_type/test_fp4.cpp b/test/data_type/test_fp4.cpp index f4b2bf3358..3fc74a2ef3 100644 --- a/test/data_type/test_fp4.cpp +++ b/test/data_type/test_fp4.cpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" #include "ck/utility/scaled_type_convert.hpp" +#include "ck/utility/env.hpp" using ck::e8m0_bexp_t; using ck::f4_convert_rne; @@ -38,6 +39,11 @@ TEST(FP4, ConvertFP32Nearest) // convert maximal float to fp4 and back, check if clipped to 6.0 ASSERT_NEAR( max_fp4, type_convert(f4_convert_rne(std::numeric_limits::max())), abs_tol); + + // convert +/-7.0 to fp4 and back, check if clipped to +/-6.0 + ASSERT_NEAR(-max_fp4, type_convert(f4_convert_rne(-7.0f)), 0.0); + ASSERT_NEAR(max_fp4, type_convert(f4_convert_rne(7.0f)), 0.0); + // positive norm float value to fp4 and back, check if holds float pos_float = 1.0f; ASSERT_NEAR(pos_float, type_convert(f4_convert_rne(pos_float)), abs_tol); @@ -468,3 +474,54 @@ TEST(FP4, TestAsType32) test_vec.at(i + 1)); }); } + +TEST(FP4, TestAllValues) +{ + constexpr std::array e2m1ValuesOCP = { + // clang-format off + 0.0000000000, 0.5000000000, + 1.0000000000, 1.5000000000, + 2.0000000000, 3.0000000000, + 4.0000000000, 6.0000000000, + -0.0000000000, -0.5000000000, + -1.0000000000, -1.5000000000, + -2.0000000000, -3.0000000000, + -4.0000000000, -6.0000000000 + // clang-format on + }; + + constexpr uint8_t e2m1BitsOCP[] = { + // clang-format off + 0b0000, 0b0001, + 0b0010, 0b0011, + 0b0100, 0b0101, + 0b0110, 0b0111, + 0b1000, 0b1001, + 0b1010, 0b1011, + 0b1100, 0b1101, + 0b1110, 0b1111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("FP4 Table\n"); + ck::static_for<0, 16, 1>{}([&](auto i) { + float fp = type_convert(f4_t(e2m1BitsOCP[i])); + ASSERT_EQ(fp, e2m1ValuesOCP[i]); + + f4_t fp4 = type_convert(e2m1ValuesOCP[i]); + ASSERT_EQ(fp4 & 0xF, e2m1BitsOCP[i] & 0xF); + 
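+        // with CK_LOGGING enabled, dump each entry's bit pattern and value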
if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 3; j >= 0; --j) + { + printf("%c", (e2m1BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e2m1BitsOCP[i], e2m1ValuesOCP[i]); + } + }); +} diff --git a/test/data_type/test_fp6.cpp b/test/data_type/test_fp6.cpp index cf91e69db3..6d4aec1d9a 100644 --- a/test/data_type/test_fp6.cpp +++ b/test/data_type/test_fp6.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" using ck::e8m0_bexp_t; @@ -34,6 +35,11 @@ TEST(FP6, ConvertFP32Nearest) ASSERT_NEAR(0.0f, type_convert(f6_convert_rne(0.0f)), 0.0f); // convert maximal f6_t to float and check if equal to max_fp6 ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(max_fp6)), 0.0f); + + // convert maximal +/-8.0 to fp6 and check if equal to +/-max_fp6 + ASSERT_NEAR(-max_fp6, type_convert(f6_convert_rne(-8.0f)), 0.0f); + ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(8.0f)), 0.0f); + // convert maximal float to fp6 and back, check if clipped to max_fp6 ASSERT_NEAR( max_fp6, type_convert(f6_convert_rne(std::numeric_limits::max())), 0.0f); @@ -265,20 +271,24 @@ TEST(FP6, TestAsType16x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = f6x16_pk_t{test_vec}; }); + // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])) + << " i = " << i << "; left = " + << type_convert(left_vec.template AsType()(Number<0>{}).unpack(i)) + << " -- right = " + << type_convert(static_cast(test_vec[static_cast(i)])) << " (" + << static_cast(test_vec[static_cast(i)]) << ")" << std::endl; }); } @@ -327,23 +337,23 @@ TEST(FP6, TestAsType16x2) // check default CTOR ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(right_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).unpack(idx_element), + 0); }); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec[i]); + right_vec.template AsType()(Number{}) = f6x16_pk_t{test_vec[i]}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(left_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - static_cast(test_vec[idx_vector][static_cast(idx_element)])); + ASSERT_EQ( + left_vec.template AsType()(Number{}).unpack(idx_element), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); }); }); } @@ -367,19 +377,77 @@ TEST(FP6, 
TestAsType32x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x32_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = f6x32_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); + }); +} + +TEST(FP6, TestAllValues) +{ + constexpr std::array e2m3ValuesOCP = { + // clang-format off + 0.0000000000, 0.1250000000, 0.2500000000, 0.3750000000, 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000, + 1.0000000000, 1.1250000000, 1.2500000000, 1.3750000000, 1.5000000000, 1.6250000000, 1.7500000000, 1.8750000000, + 2.0000000000, 2.2500000000, 2.5000000000, 2.7500000000, 3.0000000000, 3.2500000000, 3.5000000000, 3.7500000000, + 4.0000000000, 4.5000000000, 5.0000000000, 5.5000000000, 6.0000000000, 6.5000000000, 7.0000000000, 7.5000000000, + -0.0000000000, -0.1250000000, -0.2500000000, -0.3750000000, -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000, + -1.0000000000, -1.1250000000, -1.2500000000, -1.3750000000, -1.5000000000, -1.6250000000, -1.7500000000, -1.8750000000, + -2.0000000000, -2.2500000000, -2.5000000000, -2.7500000000, -3.0000000000, -3.2500000000, -3.5000000000, -3.7500000000, + -4.0000000000, -4.5000000000, -5.0000000000, -5.5000000000, -6.0000000000, -6.5000000000, -7.0000000000, -7.5000000000 + // clang-format on + }; + + constexpr uint8_t e2m3BitsOCP[] = { + // clang-format off + 0b000000, 0b000001, 0b000010, 0b000011, + 0b000100, 0b000101, 0b000110, 0b000111, + 0b001000, 0b001001, 0b001010, 0b001011, + 0b001100, 0b001101, 0b001110, 0b001111, + 0b010000, 0b010001, 0b010010, 0b010011, + 0b010100, 0b010101, 0b010110, 0b010111, + 0b011000, 0b011001, 0b011010, 0b011011, + 0b011100, 0b011101, 0b011110, 0b011111, + 0b100000, 0b100001, 0b100010, 0b100011, + 0b100100, 0b100101, 0b100110, 0b100111, + 0b101000, 0b101001, 0b101010, 0b101011, + 0b101100, 0b101101, 0b101110, 0b101111, + 0b110000, 0b110001, 0b110010, 0b110011, + 0b110100, 0b110101, 0b110110, 0b110111, + 0b111000, 0b111001, 0b111010, 0b111011, + 0b111100, 0b111101, 0b111110, 0b111111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("FP6 Table\n"); + ck::static_for<0, 64, 1>{}([&](auto i) { + float fp = type_convert(f6_t(e2m3BitsOCP[i])); + ASSERT_EQ(fp, e2m3ValuesOCP[i]); + + f6_t fp6 = type_convert(e2m3ValuesOCP[i]); + ASSERT_EQ(fp6 & 0x3F, e2m3BitsOCP[i] & 0x3F); + + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 5; j >= 0; --j) + { + printf("%c", (e2m3BitsOCP[i] & (1 << j)) ? 
diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp
index fddb8288a6..5e2aedd35e 100644
--- a/test/mx_mfma_op/mx_mfma_op.cpp
+++ b/test/mx_mfma_op/mx_mfma_op.cpp
@@ -5,9 +5,12 @@
 
 #include "mx_mfma_op.hpp"
 
+using ck::bf6_t;
+using ck::bf8_t;
 using ck::e8m0_bexp_t;
 using ck::f4_t;
 using ck::f4x2_pk_t;
+using ck::f6_t;
 using ck::f8_t;
 using ck::half_t;
 using ck::type_convert;
@@ -17,13 +20,15 @@ using ck::type_convert;
  *
  * @param init - selects initialization algorithm for A and B tensors
  */
-template <typename AType, typename BType, ck::MFMA_F8F6F4 mfma>
-bool run_mfma_km_kn_nm_test(ck::index_t init)
+template <typename AType,
+          typename BType,
+          ck::MFMA_F8F6F4 mfma,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+bool run_mfma_test(ck::index_t init)
 {
-    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
-    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-    using CLayout = ck::tensor_layout::gemm::ColumnMajor;
-
     using AccType    = float; // only MFMA_F32 instructions supported
     using CPUAccType = AccType;
 
@@ -53,74 +58,153 @@ bool run_mfma_km_kn_nm_test(ck::index_t init)
     return pass;
 }
 
+const ck::index_t common_init = -4; // set to "< 0" for test-specific initializations
+
 TEST(MFMA, FP8MFMA16x16x128)
 {
-    auto AB_init = 5;
-    auto pass    = run_mfma_km_kn_nm_test<f8_t, f8_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
+    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<f8_t, f8_t, ck::MFMA_F8F6F4::F32_16x16x128, ALayout, BLayout, CLayout>(
+            AB_init);
     EXPECT_TRUE(pass);
 }
 
-TEST(MFMA, FP8MFMA32x32x64)
+TEST(MFMA, BF8MFMA16x16x128)
 {
-    auto AB_init = 5;
-    auto pass    = run_mfma_km_kn_nm_test<f8_t, f8_t, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
+    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<bf8_t, bf8_t, ck::MFMA_F8F6F4::F32_16x16x128, ALayout, BLayout, CLayout>(
+            AB_init);
     EXPECT_TRUE(pass);
 }
 
-/**
- * @brief Run the test for the given MFMA instruction
- *
- * @param init - selects initialization algorithm for A and B tensors
- */
-template <typename AType, typename BType, ck::MFMA_F8F6F4 mfma>
-bool run_mfma_mk_kn_mn_test(ck::index_t init)
+TEST(MFMA, FP4MFMA16x16x128)
 {
     using ALayout = ck::tensor_layout::gemm::RowMajor;
     using BLayout = ck::tensor_layout::gemm::ColumnMajor;
     using CLayout = ck::tensor_layout::gemm::RowMajor;
 
-    using AccType    = float; // only MFMA_F32 instructions supported
-    using CPUAccType = AccType;
-
-    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
-    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
-    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
-    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
-
-    const auto mfma_kernel = ck::
-        matmul<AType, BType, CPUAccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout, mfma>;
-
-    bool pass = true;
-
-    pass = ck::mfma_test::TestMFMA<AType, BType, CPUAccType, ALayout, BLayout, CLayout, BLOCK_M, BLOCK_N, BLOCK_K>{}(mfma_kernel, init);
-
-    return pass;
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<f4_t, f4_t, ck::MFMA_F8F6F4::F32_16x16x128, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
 }
 
-TEST(MFMA, FP4MFMA16x16x128)
+TEST(MFMA, FP6MFMA16x16x128)
 {
-    auto AB_init = 4;
-    auto pass    = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, ck::MFMA_F8F6F4::F32_16x16x128>(
-        AB_init);
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<f6_t, f6_t, ck::MFMA_F8F6F4::F32_16x16x128, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MFMA, BF6MFMA16x16x128)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<bf6_t, bf6_t, ck::MFMA_F8F6F4::F32_16x16x128, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MFMA, FP8MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<f8_t, f8_t, ck::MFMA_F8F6F4::F32_32x32x64, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MFMA, BF8MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<bf8_t, bf8_t, ck::MFMA_F8F6F4::F32_32x32x64, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
+}
 
 TEST(MFMA, FP4MFMA32x32x64)
 {
-    auto AB_init = 4;
-    auto pass    = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, ck::MFMA_F8F6F4::F32_32x32x64>(
-        AB_init);
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<f4_t, f4_t, ck::MFMA_F8F6F4::F32_32x32x64, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MFMA, FP6MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<f6_t, f6_t, ck::MFMA_F8F6F4::F32_32x32x64, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MFMA, BF6MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass =
+        run_mfma_test<bf6_t, bf6_t, ck::MFMA_F8F6F4::F32_32x32x64, ALayout, BLayout, CLayout>(
+            AB_init);
+    EXPECT_TRUE(pass);
 }
 
@@ -129,15 +213,18 @@ TEST(MFMA, FP4MFMA32x32x64)
 /**
  * @brief Run the test for the given MFMA instruction
  *
  * @param init - selects initialization algorithm for A and B tensors
  */
-template <typename AType, typename BType, ck::MFMA_F8F6F4 mfma>
-bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
+template <typename AType,
+          typename BType,
+          ck::MFMA_F8F6F4 mfma,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+bool run_mxmfma_test(ck::index_t init)
 {
     static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
                       mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
                   "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
 
-    using ALayout = ck::tensor_layout::gemm::RowMajor;
-    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-    using CLayout = ck::tensor_layout::gemm::RowMajor;
-
     using AccType   = float;           // only MFMA_F32 instructions supported
     using ScaleType = ck::e8m0_bexp_t; // biased exponent type
 
@@ -181,34 +268,170 @@ bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
 TEST(MXMFMA, MXFP8MFMA16x16x128)
 {
-    auto AB_init = 5;
-    auto pass =
-        run_mxmfma_mk_kn_mn_test<f8_t, f8_t, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<f8_t,
+                                   f8_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_16x16x128,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
     EXPECT_TRUE(pass);
 }
 
 TEST(MXMFMA, MXFP8MFMA32x32x64)
 {
-    auto AB_init = 5;
-    auto pass =
-        run_mxmfma_mk_kn_mn_test<f8_t, f8_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<f8_t,
+                                   f8_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
+    EXPECT_TRUE(pass);
+}
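A note on the scale type used by run_mxmfma_test: ck::e8m0_bexp_t stores only a biased exponent, so a raw byte b represents the scale 2^(b - 127). That is what makes generator ranges such as {126, 129} and {122, 129} later in this patch correspond to the scale sets called out in the comments. A standalone sketch of the decoding (e8m0_to_float is a hypothetical helper, assuming the OCP E8M0 bias of 127 and ignoring the NaN encoding 0xFF):

#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper: an E8M0 byte encodes the power-of-two scale 2^(bits - 127).
float e8m0_to_float(uint8_t bits) { return std::ldexp(1.0f, int(bits) - 127); }

int main()
{
    // GeneratorTensor_2{126, 129} draws integers from [126, 128] ...
    assert(e8m0_to_float(126) == 0.5f);
    assert(e8m0_to_float(127) == 1.0f);
    assert(e8m0_to_float(128) == 2.0f); // ... hence the "scales: {0.5, 1, 2}" comments
    // GeneratorTensor_2{122, 129} starts at 2^-5, hence "scales: [1/32, ..., 2]"
    assert(e8m0_to_float(122) == 0.03125f);
    return 0;
}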
+
+TEST(MXMFMA, MXBF8MFMA16x16x128)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<bf8_t,
+                                   bf8_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_16x16x128,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXBF8MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<bf8_t,
+                                   bf8_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXFP6MFMA16x16x128)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<f6_t,
+                                   f6_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_16x16x128,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXFP6MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<f6_t,
+                                   f6_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXBF6MFMA16x16x128)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<bf6_t,
+                                   bf6_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_16x16x128,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXBF6MFMA32x32x64)
+{
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<bf6_t,
+                                   bf6_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
     EXPECT_TRUE(pass);
 }
 
 TEST(MXMFMA, MXFP4MFMA16x16x128)
 {
-    auto AB_init = 4;
-    auto pass =
-        run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
-            AB_init);
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<f4_t,
+                                   f4_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_16x16x128,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
     EXPECT_TRUE(pass);
 }
 
 TEST(MXMFMA, MXFP4MFMA32x32x64)
 {
-    auto AB_init = 4;
-    auto pass =
-        run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
-            AB_init);
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    auto AB_init = (common_init < 0) ? 5 : common_init;
+    auto pass    = run_mxmfma_test<f4_t,
+                                   f4_t,
+                                   ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                                   ALayout,
+                                   BLayout,
+                                   CLayout>(AB_init);
     EXPECT_TRUE(pass);
 }
diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp
index 9ce871cfb1..4cab411cb4 100644
--- a/test/mx_mfma_op/mx_mfma_op.hpp
+++ b/test/mx_mfma_op/mx_mfma_op.hpp
@@ -151,6 +151,8 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
     // Reg 7 [24:31]   |  K79    |  K95    |  K111   |  K127   | v[31]  || Reg 7 [24:31]   |  K47    |  K63    | v[31] |
     // clang-format on
 
+    static_assert(!is_packed_type_v<AType>, "Packed type is not supported");
+
     static constexpr int32_t WAVE_SIZE = 64;
 
     // Here we want to load from rows of A in chunks of 16 elements each.
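The is_packed_type_v guard above, and the packed_type/packed_size_v traits used throughout the rest of this patch, follow the pattern sketched below. The definitions here are illustrative stand-ins rather than CK's actual implementations; the point is that both the scalar (f6_t) and its storage container (f6x32_pk_t) report the same element count, so offsets can be divided by packed_size_v uniformly.

#include <type_traits>

struct f4_t; struct f4x2_pk_t; // stand-ins for the real CK types
struct f6_t; struct f6x32_pk_t;
struct f8_t;

// Default: a type packs itself, one value per storage element.
template <typename T>
struct packed_type { using type = T; static constexpr int packed_size = 1; };

// A scalar maps to its packed container...
template <> struct packed_type<f4_t> { using type = f4x2_pk_t; static constexpr int packed_size = 2; };
template <> struct packed_type<f6_t> { using type = f6x32_pk_t; static constexpr int packed_size = 32; };
// ...and the container reports the same element count.
template <> struct packed_type<f4x2_pk_t> { using type = f4x2_pk_t; static constexpr int packed_size = 2; };
template <> struct packed_type<f6x32_pk_t> { using type = f6x32_pk_t; static constexpr int packed_size = 32; };

template <typename T> using packed_type_t = typename packed_type<T>::type;
template <typename T> inline constexpr int packed_size_v = packed_type<T>::packed_size;
template <typename T> inline constexpr bool is_packed_type_v = packed_size_v<T> > 1;

static_assert(std::is_same_v<packed_type_t<f6_t>, f6x32_pk_t>);
static_assert(packed_size_v<f6_t> == 32 && packed_size_v<f6x32_pk_t> == 32);
static_assert(is_packed_type_v<f4x2_pk_t> && !is_packed_type_v<f8_t>);

Under this scheme the col-major loader simply refuses packed element types, while the row-major paths below branch on is_packed_type_v to widen the chunk size.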
@@ -270,12 +272,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
     // Reg 3 [8:15]    | K26K27  | K58K59  | K90K91  | K122K123 | v[13]  || Reg 3 [8:15]    | K26K27  | K58K59  | v[13] |
     // Reg 3 [16:23]   | K28K29  | K60K61  | K92K93  | K124K125 | v[14]  || Reg 3 [16:23]   | K28K29  | K60K61  | v[14] |
     // Reg 3 [24:31]   | K30K31  | K62K63  | K94K95  | K126K127 | v[15]  || Reg 3 [24:31]   | K30K31  | K62K63  | v[15] |
+
+    // Register Mapping for 16x128 for FP6:                                               || Register Mapping for 32x64 for FP6:
+    // Size             | BLOCK_M    | BLOCK_M    | BLOCK_M    | BLOCK_M     |            || Size             | BLOCK_M    | BLOCK_M    |         |
+    // M                | 0  ...  15 | 0  ...  15 | 0  ...  15 | 0  ...  15  | Vector     || M                | 0  ...  31 | 0  ...  31 | Vector  |
+    // Thread Id        | 0  ...  15 | 16 ...  31 | 32 ...  47 | 48 ...  63  | Element    || Thread Id        | 0  ...  31 | 32 ...  63 | Element |
+    // Register Element |------------|------------|------------|-------------|------------|| Register Element |------------|------------|---------|
+    // Reg 0-2 [0:95]   | K = 0-15   | K = 32-47  | K = 64-79  | K = 96-111  | v[0]       || Reg 0-2 [0:95]   | K = 0-15   | K = 32-47  | v[0]    |
+    // Reg 3-5 [0:95]   | K = 16-31  | K = 48-63  | K = 80-95  | K = 112-127 | v[0]       || Reg 3-5 [0:95]   | K = 16-31  | K = 48-63  | v[0]    |
+
     // clang-format on
 
     static constexpr int32_t WAVE_SIZE = 64;
 
+    // FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
+    // FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
+    // FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
+
+    constexpr index_t num_chunks = is_packed_type_v<AType> ? 1 : 2;
+
     // Here we want to load from rows of A in chunks of 16 elements each.
-    static constexpr uint32_t chunk_size = 16;
+    constexpr uint32_t chunk_size = is_packed_type_v<AType> ? 32 : 16;
 
     // each chunk is separated by offset
     static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
 
@@ -283,43 +301,35 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
     // To start the loading process, let's visualize in 2D coords.
     // Each thread will load 32 elements.
     // We need to know where they start, and where the next elements are.
-    auto startCoord2D =
-        std::make_pair(threadIdx.x % BLOCK_M,                 // Row {0-31} | {0-15}
-                       (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
+    // FP8/6/4 Row {0-31}  | {0-15}
+    // FP8     Col {0, 16} | {0, 16, 32, 48}
+    // FP6/4   Col {0, 32} | {0, 32, 64, 96}
+    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, (threadIdx.x / BLOCK_M) * chunk_size);
 
-    // auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
     auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
 
     // Flatten to 1D row_major offsets.
     auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
 
-    // BLOCK_K is a stride in A matrix
-    auto startOffset = row_major(
-        startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
-    // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
-    //     (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
-    auto kMajorOffset =
-        row_major(majorStepCoord2D,
-                  BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
-
-    using ARawT        = typename scalar_type<AFragT>::type;
-    using AScalarFragT = vector_type<ARawT, chunk_size>::type;
-
-    constexpr index_t num_chunks =
-        (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
+    using ARawT = typename scalar_type<AFragT>::type;
+    using AScalarChunkT =
+        typename vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
 
     union
     {
         AFragT frag;
-        AScalarFragT chunks[num_chunks];
+        AScalarChunkT chunks[num_chunks];
     } fragA{};
 
-    const AScalarFragT* fragPtr;
+    const AScalarChunkT* fragPtr;
+
+    // BLOCK_K is a stride in A matrix
+    auto startOffset  = row_major(startCoord2D, BLOCK_K) / packed_size_v<AType>;
+    auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K) / packed_size_v<AType>;
 
     for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
     {
-        fragPtr = reinterpret_cast<const AScalarFragT*>(input_ptr + startOffset +
-                                                        chunk_idx * kMajorOffset);
+        fragPtr = reinterpret_cast<const AScalarChunkT*>(input_ptr + startOffset +
+                                                         chunk_idx * kMajorOffset);
         fragA.chunks[chunk_idx] = *fragPtr;
     }
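A worked example of the new offset arithmetic, for the FP6 16x128 case (BLOCK_M = 16, BLOCK_K = 128, packed size 32): dividing the flattened element offset by the packed size converts it into an offset in f6x32_pk_t units. The snippet below runs the same computation on the host; threadIdx_x stands in for the device built-in.

#include <cassert>
#include <cstdint>
#include <utility>

int main()
{
    constexpr uint32_t WAVE_SIZE = 64, BLOCK_M = 16, BLOCK_K = 128;
    constexpr uint32_t packed_size  = 32;                               // f6x32_pk_t
    constexpr uint32_t chunk_size   = 32;                               // packed-type chunk
    constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; // 128 elements

    auto row_major = [](std::pair<uint32_t, uint32_t> c, uint32_t ld) {
        return c.first * ld + c.second;
    };

    for(uint32_t threadIdx_x = 0; threadIdx_x < WAVE_SIZE; ++threadIdx_x)
    {
        // Same coordinates as load_A_row_major: row 0..15, column 0/32/64/96.
        auto startCoord2D =
            std::make_pair(threadIdx_x % BLOCK_M, (threadIdx_x / BLOCK_M) * chunk_size);
        auto startOffset = row_major(startCoord2D, BLOCK_K) / packed_size;

        // e.g. thread 17 starts at row 1, column 32: element 1*128+32 = 160,
        // which is packed element 160/32 = 5.
        if(threadIdx_x == 17)
            assert(startOffset == 5);
    }
    (void)chunk_offset; // num_chunks == 1 for FP6, so the major step is unused here
    return 0;
}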
@@ -488,12 +498,27 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
     // Reg 3 [8:15]    | K26K27  | K58K59  | K90K91  | K122K123 | v[13]  || Reg 3 [8:15]    | K26K27  | K58K59  | v[13] |
     // Reg 3 [16:23]   | K28K29  | K60K61  | K92K93  | K124K125 | v[14]  || Reg 3 [16:23]   | K28K29  | K60K61  | v[14] |
     // Reg 3 [24:31]   | K30K31  | K62K63  | K94K95  | K126K127 | v[15]  || Reg 3 [24:31]   | K30K31  | K62K63  | v[15] |
+
+    // Register Mapping for 16x128 for FP6:                                               || Register Mapping for 32x64 for FP6:
+    // Size             | BLOCK_N    | BLOCK_N    | BLOCK_N    | BLOCK_N     |            || Size             | BLOCK_N    | BLOCK_N    |         |
+    // N                | 0  ...  15 | 0  ...  15 | 0  ...  15 | 0  ...  15  | Vector     || N                | 0  ...  31 | 0  ...  31 | Vector  |
+    // Thread Id        | 0  ...  15 | 16 ...  31 | 32 ...  47 | 48 ...  63  | Element    || Thread Id        | 0  ...  31 | 32 ...  63 | Element |
+    // Register Element |------------|------------|------------|-------------|------------|| Register Element |------------|------------|---------|
+    // Reg 0-2 [0:95]   | K = 0-15   | K = 32-47  | K = 64-79  | K = 96-111  | v[0]       || Reg 0-2 [0:95]   | K = 0-15   | K = 32-47  | v[0]    |
+    // Reg 3-5 [0:95]   | K = 16-31  | K = 48-63  | K = 80-95  | K = 112-127 | v[0]       || Reg 3-5 [0:95]   | K = 16-31  | K = 48-63  | v[0]    |
+
     // clang-format on
 
     static constexpr int32_t WAVE_SIZE = 64;
 
+    // FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
+    // FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
+    // FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
+
+    constexpr index_t num_chunks = is_packed_type_v<BType> ? 1 : 2;
+
     // Here we want to load from cols of B in chunks of 16 elements each.
-    static constexpr uint32_t chunk_size = 16;
+    constexpr uint32_t chunk_size = is_packed_type_v<BType> ? 32 : 16;
 
     // each chunk is separated by an offset
     static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64
 
@@ -501,44 +526,36 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
     // To start the loading process, let's visualize in 2D coords.
     // Each thread will load 32 elements.
     // We need to know where they start, and where the next elements are.
-    auto startCoord2D =
-        std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48}
-                       threadIdx.x % BLOCK_N);               // Col {0-31} | {0-15}
+    // FP8/6/4 Col {0-31}  | {0-15}
+    // FP8     Row {0, 16} | {0, 16, 32, 48}
+    // FP6/4   Row {0, 32} | {0, 32, 64, 96}
+    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, threadIdx.x % BLOCK_N);
 
     // Flatten to 1D col_major offsets.
     auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
 
-    // auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols
     auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
 
-    // BLOCK_K is a stride in B matrix
-    auto startOffset = col_major(
-        startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
-    // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
-    //     (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
-    auto kMajorOffset =
-        col_major(majorStepCoord2D,
-                  BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
-
-    using BRawT        = typename scalar_type<BFragT>::type;
-    using BScalarFragT = vector_type<BRawT, chunk_size>::type;
-
-    constexpr index_t num_chunks =
-        (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 1 : 2);
+    using BRawT = typename scalar_type<BFragT>::type;
+    using BScalarChunkT =
+        typename vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
 
     union
     {
         BFragT frag;
-        BScalarFragT chunks[num_chunks];
+        BScalarChunkT chunks[num_chunks];
     } fragB{};
 
-    const BScalarFragT* fragPtr;
+    const BScalarChunkT* fragPtr;
 
-    for(index_t chunk = 0; chunk < num_chunks; chunk++)
+    // BLOCK_K is a stride in B matrix
+    auto startOffset  = col_major(startCoord2D, BLOCK_K) / packed_size_v<BType>;
+    auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K) / packed_size_v<BType>;
+
+    for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
     {
-        fragPtr =
-            reinterpret_cast<const BScalarFragT*>(input_ptr + startOffset + chunk * kMajorOffset);
-        fragB.chunks[chunk] = *fragPtr;
+        fragPtr = reinterpret_cast<const BScalarChunkT*>(input_ptr + startOffset +
+                                                         chunk_idx * kMajorOffset);
+        fragB.chunks[chunk_idx] = *fragPtr;
     }
 
     return fragB.frag;
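Both loaders fill the destination fragment through the same union trick: the fragment and an array of num_chunks smaller chunk vectors alias one piece of storage, and each loop iteration writes one chunk. A simplified host-side model of that pattern (plain float arrays instead of CK's vector_type; memcpy stands in for the kernel's reinterpret_cast loads, which keeps the sketch free of strict-aliasing concerns):

#include <cassert>
#include <cstring>

int main()
{
    constexpr int frag_elems  = 32, num_chunks = 2; // the FP8 configuration
    constexpr int chunk_elems = frag_elems / num_chunks;

    using ChunkT = float[chunk_elems];

    float input[64];
    for(int i = 0; i < 64; ++i)
        input[i] = float(i);

    union
    {
        float frag[frag_elems];
        ChunkT chunks[num_chunks];
    } fragA{};

    const int startOffset = 4, kMajorOffset = 32; // stand-ins for the computed offsets
    for(int chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
        std::memcpy(fragA.chunks[chunk_idx],
                    input + startOffset + chunk_idx * kMajorOffset,
                    sizeof(ChunkT));

    // Chunk 0 landed in frag[0..15], chunk 1 in frag[16..31].
    assert(fragA.frag[0] == 4.0f && fragA.frag[chunk_elems] == 36.0f);
    return 0;
}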
@@ -904,20 +921,22 @@ template <typename AType,
-__global__ void matmul(const AType* a, const BType* b, CType* c)
+__global__ void matmul(const typename packed_type<AType>::type* a,
+                       const typename packed_type<BType>::type* b,
+                       CType* c)
 {
+    using PackedAType            = typename packed_type<AType>::type;
+    constexpr auto packed_size_a = packed_type<AType>::packed_size;
+    using PackedBType            = typename packed_type<BType>::type;
+    constexpr auto packed_size_b = packed_type<BType>::packed_size;
+
     constexpr int WAVE_SIZE = 64;
     assert(threadIdx.x < WAVE_SIZE);
     assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
 
-    using AFragT =
-        vector_type<AType, 32 / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
-    using BFragT =
-        vector_type<BType, 32 / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
+    using AFragT = vector_type<PackedAType, 32 / packed_size_a>::type;
+    using BFragT = vector_type<PackedBType, 32 / packed_size_b>::type;
+    using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
 
     using AccumFragT    = vector_type<float, BLOCK_M * BLOCK_N / WAVE_SIZE>;
     using RawAccumFragT = vector_type<float, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -931,11 +950,11 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
 
     // Load the inputs.
     if constexpr(is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor>)
     {
-        fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
+        fragA = load_A_row_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
     }
     else
     {
-        fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
+        fragA = load_A_col_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
     }
 
     if constexpr(is_same_v<BLayout, ck::tensor_layout::gemm::RowMajor>)
@@ -944,7 +963,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
     }
     else
     {
-        fragB = load_B_col_major<BType, BFragT, BLOCK_N, BLOCK_K>(b);
+        fragB = load_B_col_major<PackedBType, BFragT, BLOCK_N, BLOCK_K>(b);
     }
 
     // Matrix multiply-accumulate using MFMA units
@@ -979,21 +998,24 @@ template <typename AType,
-__global__ void
-matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
+__global__ void matmul(const packed_type_t<AType>* a,
+                       const ScaleType* xa,
+                       const packed_type_t<BType>* b,
+                       const ScaleType* xb,
+                       CType* c)
 {
+    using PackedAType            = packed_type_t<AType>;
+    constexpr auto packed_size_a = packed_size_v<AType>;
+    using PackedBType            = packed_type_t<BType>;
+    constexpr auto packed_size_b = packed_size_v<BType>;
+
     constexpr int WAVE_SIZE = 64;
     assert(threadIdx.x < WAVE_SIZE);
     assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
 
-    using AFragT =
-        vector_type<AType, 32 / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
-    using BFragT =
-        vector_type<BType, 32 / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
+    using AFragT = vector_type<PackedAType, 32 / packed_size_a>::type;
+    using BFragT = vector_type<PackedBType, 32 / packed_size_b>::type;
+    using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
 
     using AccumFragT    = vector_type<float, BLOCK_M * BLOCK_N / WAVE_SIZE>;
     using RawAccumFragT = vector_type<float, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -1011,9 +1033,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
     // Load the inputs.
     if constexpr(is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor>)
     {
-        fragA =
-            load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
-                a, xa, fragXa);
+        fragA = load_mx_A_row_major<PackedAType,
+                                    AFragT,
+                                    ScaleType,
+                                    AScaleFragT,
+                                    BLOCK_M,
+                                    BLOCK_K,
+                                    BLOCK_X>(a, xa, fragXa);
     }
     else
     {
@@ -1026,9 +1052,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
     }
     else
     {
-        fragB =
-            load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_N, BLOCK_K, BLOCK_X>(
-                b, xb, fragXb);
+        fragB = load_mx_B_col_major<PackedBType,
+                                    BFragT,
+                                    ScaleType,
+                                    BScaleFragT,
+                                    BLOCK_N,
+                                    BLOCK_K,
+                                    BLOCK_X>(b, xb, fragXb);
     }
 
     // Scaled Matrix multiply-accumulate using MFMA units
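For orientation, the computation these scaled kernels perform is a block-scaled GEMM: each run of BLOCK_X consecutive K elements shares one E8M0 scale per row of A and per column of B (BLOCK_X is 32 for the MX formats). The host-side sketch below states that contraction explicitly; it is a hypothetical reference, not CK's ReferenceGemm, and it assumes inputs already decoded to float, with A row-major, B column-major, and both scale tensors indexed row-major.

#include <vector>

// C[m][n] = sum_k (xa[m][k/BLOCK_X] * A[m][k]) * (xb[k/BLOCK_X][n] * B[k][n])
std::vector<float> mx_gemm_reference(const std::vector<float>& a,  // M x K, row-major
                                     const std::vector<float>& xa, // M x (K/BLOCK_X) scales
                                     const std::vector<float>& b,  // K x N, column-major
                                     const std::vector<float>& xb, // (K/BLOCK_X) x N scales
                                     int M, int N, int K, int BLOCK_X)
{
    std::vector<float> c(static_cast<std::size_t>(M) * N, 0.0f);
    const int KX = K / BLOCK_X;
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += xa[m * KX + k / BLOCK_X] * a[m * K + k] * // scaled A element
                                xb[(k / BLOCK_X) * N + n] * b[n * K + k]; // scaled B element
    return c;
}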
@@ -1151,6 +1181,11 @@ template <typename AType,
 struct TestMXMFMA
 {
+    using PackedAType                   = typename packed_type<AType>::type;
+    static constexpr auto packed_size_a = packed_type<AType>::packed_size;
+    using PackedBType                   = typename packed_type<BType>::type;
+    static constexpr auto packed_size_b = packed_type<BType>::packed_size;
+
     auto PrepareGemmTensors(const GemmParams& params, index_t init)
     {
         auto f_host_tensor_descriptor =
@@ -1167,11 +1202,11 @@ struct TestMXMFMA
             }
         };
 
-        Tensor<AType> a_m_k(
+        Tensor<PackedAType> a_m_k(
             f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
         Tensor<ScaleType> a_scales(
             f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
-        Tensor<BType> b_n_k(
+        Tensor<PackedBType> b_n_k(
             f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
         Tensor<ScaleType> b_scales(
             f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
@@ -1183,51 +1218,44 @@ struct TestMXMFMA
         switch(init)
         {
         case 0:
-            a_m_k.GenerateTensorValue(GeneratorTensor_1<AType>{1.0f});
-            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
+            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.5f}});
             // NOTE: not all numbers are representable in FP8, BF8, etc.
             // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
-            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BType, 1>{});
+            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
             b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
             break;
         case 1:
             // results in C = {K}
-            a_m_k.GenerateTensorValue(GeneratorTensor_1<AType>{1.0f});
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
             a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
-            b_n_k.GenerateTensorValue(GeneratorTensor_1<BType>{1.0f});
+            b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
             b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
             break;
         case 2:
             // expect small round off errors
-            a_m_k.GenerateTensorValue(GeneratorTensor_3<AType>{-2.0, 2.0});
+            a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-2.0, 2.0});
             a_scales.GenerateTensorValue(
                 GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
-            b_n_k.GenerateTensorValue(GeneratorTensor_3<BType>{-2.0, 2.0});
+            b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-2.0, 2.0});
             b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
             break;
         case 3:
             // expect small round off errors
-            a_m_k.GenerateTensorValue(GeneratorTensor_4<AType>(0, 1));
+            a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(0, 1, time(nullptr)));
             a_scales.GenerateTensorValue(
                 GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
-            b_n_k.GenerateTensorValue(GeneratorTensor_4<BType>(0, 1));
-            b_scales.GenerateTensorValue(
-                GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
-            break;
-        case 4:
-            a_m_k.GenerateTensorValue(GeneratorTensor_3<AType>{-1., 1.});
-            a_scales.GenerateTensorValue(
-                GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
-            b_n_k.GenerateTensorValue(GeneratorTensor_3<BType>{-1., 1.});
+            b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(0, 1, time(nullptr) / 2));
             b_scales.GenerateTensorValue(
                 GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
             break;
+
         default:
             // all initial values are representable in FP8, BF8
-            a_m_k.GenerateTensorValue(GeneratorTensor_2<AType>{-5, 6}); // Z[-5,5]
+            a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
             a_scales.GenerateTensorValue(
-                GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32, ..., 2]
-            b_n_k.GenerateTensorValue(GeneratorTensor_2<BType>{-5, 6}); // Z[-5,5]
+                GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32, ..., 2]
+            b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7}); // Z[-6,6]
             b_scales.GenerateTensorValue(
                 GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32, ..., 2]
@@ -1272,9 +1300,9 @@ struct TestMXMFMA
 
         auto host_tensors = PrepareGemmTensors(params, init);
 
-        const Tensor<AType>& a            = std::get<0>(host_tensors);
+        const Tensor<PackedAType>& a      = std::get<0>(host_tensors);
         const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
-        const Tensor<BType>& b            = std::get<2>(host_tensors);
+        const Tensor<PackedBType>& b      = std::get<2>(host_tensors);
         const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
         Tensor<CType>& c_host             = std::get<4>(host_tensors);
         Tensor<CType>& c_device           = std::get<5>(host_tensors);
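The move to Z[-6,6] in the default initializers is about exact representability in the narrowest format: E2M1 (FP4) can encode only {0, +/-0.5, +/-1, +/-1.5, +/-2, +/-3, +/-4, +/-6}, so +/-5 has to be excluded, while the FP6/BF6 and FP8/BF8 formats cover every integer in [-6, 6]. A quick self-contained check (the value set is hardcoded from the OCP E2M1 table printed earlier in this patch):

#include <cassert>
#include <cmath>
#include <initializer_list>

bool e2m1_representable(float v)
{
    for(float x : {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f})
        if(std::fabs(v) == x)
            return true;
    return false;
}

int main()
{
    // GeneratorTensor_2{-6, 7} draws integers in [-6, 6]; all are exact in FP4
    // except +/-5, which is why the default-case comment calls out "FP4 is missing 5".
    for(int i = -6; i <= 6; ++i)
        assert(e2m1_representable(float(i)) == (i != 5 && i != -5));
    return 0;
}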
@@ -1356,6 +1384,12 @@ template <typename AType,
 struct TestMFMA
 {
+
+    using PackedAType                   = typename packed_type<AType>::type;
+    static constexpr auto packed_size_a = packed_type<AType>::packed_size;
+    using PackedBType                   = typename packed_type<BType>::type;
+    static constexpr auto packed_size_b = packed_type<BType>::packed_size;
+
     auto PrepareGemmTensors(const GemmParams& params, index_t init)
     {
         auto f_host_tensor_descriptor =
@@ -1372,9 +1406,9 @@ struct TestMFMA
             }
         };
 
-        Tensor<AType> a_m_k(
+        Tensor<PackedAType> a_m_k(
             f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
-        Tensor<BType> b_n_k(
+        Tensor<PackedBType> b_n_k(
             f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
         Tensor<CType> c_m_n_host_result(
             f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
@@ -1384,34 +1418,30 @@ struct TestMFMA
         switch(init)
        {
         case 0:
-            a_m_k.GenerateTensorValue(GeneratorTensor_1<AType>{0.015625f});
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{0.625f});
             // NOTE: not all numbers are representable in FP8, BF8, etc.
-            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BType, 1>{});
+            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
             break;
         case 1:
             // results in C = {K}
-            a_m_k.GenerateTensorValue(GeneratorTensor_1<AType>{1.0f});
-            b_n_k.GenerateTensorValue(GeneratorTensor_1<BType>{1.0f});
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
+            b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
             break;
         case 2:
-            // expect small round off errors
-            a_m_k.GenerateTensorValue(GeneratorTensor_3<AType>{-5, 5});
-            b_n_k.GenerateTensorValue(GeneratorTensor_3<BType>{-5, 5});
+            // expect small round off errors that lead to FP8MFMA32x32x64 failures
+            a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-5, 5});
+            b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-5, 5});
             break;
         case 3:
-            // expect small round off errors
-            a_m_k.GenerateTensorValue(GeneratorTensor_4<AType>(-1, 3));
-            b_n_k.GenerateTensorValue(GeneratorTensor_4<BType>(1, 3));
-            break;
-        case 4:
-            // FP4 values case
-            a_m_k.GenerateTensorValue(GeneratorTensor_2<AType>{-4, 5});
-            b_n_k.GenerateTensorValue(GeneratorTensor_2<BType>{-4, 5});
+            // expect small round off errors that lead to FP8MFMA32x32x64 failures
+            a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(-1, 3));
+            b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(1, 3));
             break;
+
         default:
-            // all initial values are representable in FP8, BF8
-            a_m_k.GenerateTensorValue(GeneratorTensor_2<AType>{-5, 6});
-            b_n_k.GenerateTensorValue(GeneratorTensor_2<BType>{-5, 6});
+            // all initial values are representable in FP8/6 and BF8/6; FP4 is missing 5
+            a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
+            b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7});
             break;
         }
 
@@ -1453,10 +1483,10 @@ struct TestMFMA
 
         auto host_tensors = PrepareGemmTensors(params, init);
 
-        const Tensor<AType>& a  = std::get<0>(host_tensors);
-        const Tensor<BType>& b  = std::get<1>(host_tensors);
-        Tensor<CType>& c_host   = std::get<2>(host_tensors);
-        Tensor<CType>& c_device = std::get<3>(host_tensors);
+        const Tensor<PackedAType>& a = std::get<0>(host_tensors);
+        const Tensor<PackedBType>& b = std::get<1>(host_tensors);
+        Tensor<CType>& c_host        = std::get<2>(host_tensors);
+        Tensor<CType>& c_device      = std::get<3>(host_tensors);
 
         using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
         auto a_element_op = PassThrough{};
         auto b_element_op = PassThrough{};
         auto c_element_op = PassThrough{};
 
-        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm