diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index 7a4ba2eb79..3b696d8cc8 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -76,6 +76,7 @@ set_slice_tile(static_distributed_tensor slice_ends) { using DstDistribution = remove_cvref_t; + using SrcDistribution = remove_cvref_t; constexpr auto sliced_dstr_yidx_ylen = detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); @@ -84,9 +85,10 @@ set_slice_tile(static_distributed_tensor(); constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); - static_assert(std::is_same_v, "wrong!"); + static_assert(std::is_same_v, SrcDistribution>, "wrong!"); - dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); + dst_tile.set_y_sliced_thread_data( + sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 21f21e1aa0..7ae624cafc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -300,6 +300,10 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl>; template +using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl< + WarpGemmAttributeMfma, + AttrNumAccess>>; +template using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index d66438528e..d1b14721f2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -93,11 +93,34 @@ struct WarpGemmAttributeMfma Impl{}(c_vec, a_vec, b_vec, bool_constant{}); } + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale, + bool_constant = {}) const + { + Impl{}.template operator()( + c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant{}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { return Impl{}(a_vec, b_vec); } + + // c_vec = a_vec * b_vec + template + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale) const + { + auto c_vec = Impl{}.template operator()(a_vec, a_scale, b_vec, b_scale); + } }; template using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; +template +struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = pk_fp4_t; + using BDataType = pk_fp4_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 128; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + auto arg_a = bit_cast(a_vec); + auto arg_b = bit_cast(b_vec); + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + c_vec, + 4, + 4, + opselA, + a_scale, + opselB, + b_scale); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_scale; +#endif + } + + // c_vec = a_vec * b_vec + template + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale) const + { +#if defined(__gfx950__) + auto arg_a = bit_cast(a_vec); + auto arg_b = bit_cast(b_vec); + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + CVecType{0.f}, + 4, + 4, + opselA, + a_scale, + opselB, + b_scale)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_scale; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 924f7c4a54..04d36cf0ea 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -122,6 +122,8 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; }; + template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index 38fd0d408b..c38175d345 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -92,6 +92,39 @@ struct WarpGemmImpl c.get_thread_buffer().template set_as(I0, c_vec); } + template + CK_TILE_DEVICE void operator()(CTensor& c, + const ATensor& a, + const BTensor& b, + const int32_t& a_scale, + const int32_t& b_scale, + bool_constant = {}) const + { + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; + + const auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + auto c_vec = c.get_thread_buffer().template get_as()[I0]; + + // c_vec += a_vec * b_vec + WarpGemmAttribute{}.template operator()( + c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant{}); + + c.get_thread_buffer().template set_as(I0, c_vec); + } + template CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const { @@ -116,6 +149,35 @@ struct WarpGemmImpl return c; } + + template + CK_TILE_DEVICE auto operator()(const ATensor& a, + const BTensor& b, + const int32_t& a_scale, + const int32_t& b_scale) const + { + using CTensor = CWarpTensor; + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + CTensor c; + + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; + + const auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + + // c_vec = a_vec * b_vec + auto c_vec = + WarpGemmAttribute{}.template operator()(a_vec, a_scale, b_vec, b_scale); + + c.get_thread_buffer().template set_as(I0, c_vec); + + return c; + } }; } // namespace ck_tile diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index b2e615b9d8..47b4419379 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -789,12 +789,12 @@ struct store_C_col_major CScalarFragT chunks[vectorSize(CFragT{}) / VW]; } fragC{cFrag}; // Initialize with input fragment - *(reinterpret_cast(output + startOffset)) = fragC.chunks[0]; - *(reinterpret_cast(output + startOffset + kMajorOffset)) = fragC.chunks[1]; - *(reinterpret_cast(output + startOffset + 2 * kMajorOffset)) = - fragC.chunks[2]; - *(reinterpret_cast(output + startOffset + 3 * kMajorOffset)) = - fragC.chunks[3]; + CScalarFragT* fragPtr; + for(uint32_t idx = 0; idx < vectorSize(CFragT{}) / VW; ++idx) + { + fragPtr = reinterpret_cast(output + startOffset + idx * kMajorOffset); + *fragPtr = fragC.chunks[idx]; + } } };