mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
Implement the fp16xint4 scale weight only kernel for Ali (#1786)
* enable int4 scale (weight only) kernel * format some files * Add unit test for int4 weight only * fixed and formatted code * fixed * formated * formated * fixed * fixed a bug in the ckProfiler, and formatted the code --------- Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -1222,6 +1222,206 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
});
|
||||
}
|
||||
|
||||
// Fuse scale
|
||||
template <typename SrcRefToOriginDisplacement,
|
||||
typename DstOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcRefToOriginDisplacement&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstData& scale,
|
||||
const DstDesc&,
|
||||
const DstOriginIdx&,
|
||||
DstBuffer& dst_buf) const
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
|
||||
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
|
||||
"at compile-time");
|
||||
|
||||
// SrcDesc and DstDesc are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
|
||||
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
|
||||
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
|
||||
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
|
||||
|
||||
// scalar per access of each dim
|
||||
constexpr auto src_scalar_per_access = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Number<SrcScalarPerVector>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<1>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// scalar step (if steping on SrcVectorDim) of each dim
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Number<1>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<0>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto dim_access_order = DimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
|
||||
#if 0
|
||||
// TODO: unable to compile
|
||||
// position in slice window
|
||||
constexpr auto data_to_origin_disp_idx =
|
||||
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
#else
|
||||
// position in slice window
|
||||
constexpr auto data_to_origin_disp_idx =
|
||||
ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
|
||||
#endif
|
||||
// src coordinate
|
||||
constexpr auto src_ref_to_data_disp_idx =
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||
|
||||
constexpr auto src_ref_to_data_disp_coord_step =
|
||||
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||
|
||||
auto src_data_coord = src_ref_coord_;
|
||||
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_tmp_vector)::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_data_coord);
|
||||
|
||||
// copy data from src_buf into src_tmp_vector
|
||||
if constexpr(SrcBuffer::IsDynamicBuffer())
|
||||
{
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
|
||||
is_src_valid);
|
||||
}
|
||||
else if constexpr(SrcBuffer::IsStaticBuffer())
|
||||
{
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
vector_type<DstData, 2> scale_vector;
|
||||
scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
|
||||
scale_vector.template AsType<DstData>()(Number<1>{}) = scale;
|
||||
|
||||
constexpr index_t pack_size = 8;
|
||||
|
||||
static_assert(SrcScalarPerVector % pack_size == 0, "");
|
||||
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::DequantPack8{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i],
|
||||
scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
constexpr index_t pack_size = 2;
|
||||
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
|
||||
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack2{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcSliceMoveStepIdx>
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc&,
|
||||
const SrcSliceMoveStepIdx& src_slice_move_step_idx)
|
||||
|
||||
Reference in New Issue
Block a user