mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Merge commit 'bedade257241fef37a28c6e540e73f1c056d27b9' into develop
This commit is contained in:
@@ -76,6 +76,7 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
|
||||
using SrcDistribution = remove_cvref_t<SrcStaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
|
||||
@@ -84,9 +85,10 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
static_assert(std::is_same_v<remove_cvref_t<decltype(sliced_dstr)>, SrcDistribution>, "wrong!");
|
||||
|
||||
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
dst_tile.set_y_sliced_thread_data(
|
||||
sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -300,6 +300,10 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
|
||||
2>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
@@ -93,11 +93,34 @@ struct WarpGemmAttributeMfma
|
||||
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}.template operator()<opselA, opselB>(
|
||||
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return Impl{}(a_vec, b_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
auto c_vec = Impl{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
|
||||
@@ -1621,6 +1621,98 @@ template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = pk_fp4_t;
|
||||
using BDataType = pk_fp4_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 128;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 32;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
auto arg_a = bit_cast<int32x4_t>(a_vec);
|
||||
auto arg_b = bit_cast<int32x4_t>(b_vec);
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
c_vec,
|
||||
4,
|
||||
4,
|
||||
opselA,
|
||||
a_scale,
|
||||
opselB,
|
||||
b_scale);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_scale;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
auto arg_a = bit_cast<int32x4_t>(a_vec);
|
||||
auto arg_b = bit_cast<int32x4_t>(b_vec);
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
CVecType{0.f},
|
||||
4,
|
||||
4,
|
||||
opselA,
|
||||
a_scale,
|
||||
opselB,
|
||||
b_scale));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_scale;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
{
|
||||
|
||||
@@ -122,6 +122,8 @@ template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16,
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, true> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, true> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
|
||||
@@ -92,6 +92,39 @@ struct WarpGemmImpl
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <index_t opselA,
|
||||
index_t opselB,
|
||||
typename CTensor,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c,
|
||||
const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}.template operator()<opselA, opselB>(
|
||||
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
|
||||
{
|
||||
@@ -116,6 +149,35 @@ struct WarpGemmImpl
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
template <index_t opselA, index_t opselB, typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
CTensor c;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
auto c_vec =
|
||||
WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -789,12 +789,12 @@ struct store_C_col_major<CType, CFragT, 32, 32>
|
||||
CScalarFragT chunks[vectorSize(CFragT{}) / VW];
|
||||
} fragC{cFrag}; // Initialize with input fragment
|
||||
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset)) = fragC.chunks[0];
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset)) = fragC.chunks[1];
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) =
|
||||
fragC.chunks[2];
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) =
|
||||
fragC.chunks[3];
|
||||
CScalarFragT* fragPtr;
|
||||
for(uint32_t idx = 0; idx < vectorSize(CFragT{}) / VW; ++idx)
|
||||
{
|
||||
fragPtr = reinterpret_cast<CScalarFragT*>(output + startOffset + idx * kMajorOffset);
|
||||
*fragPtr = fragC.chunks[idx];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user