Implement the fp16xint4 scale weight only kernel for Ali (#1786)

* enable int4 scale (weight only) kernel

* format some files

* Add unit test for int4 weight only

* fixed and formatted code

* fixed

* formated

* formated

* fixed

* fixed a bug in the ckProfiler, and formatted the code

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
Mingtao Gu
2025-01-03 18:35:21 +08:00
committed by GitHub
parent 4bc610416a
commit 4f62f6e9b7
21 changed files with 7562 additions and 4 deletions

View File

@@ -1222,6 +1222,206 @@ struct ThreadwiseTensorSliceTransfer_v4
});
}
// Fuse scale
template <typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcBuffer& src_buf,
const DstData& scale,
const DstDesc&,
const DstOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
// scalar per access of each dim
constexpr auto src_scalar_per_access = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<SrcScalarPerVector>{};
}
else
{
return Number<1>{};
}
},
Number<nDim>{});
// scalar step (if steping on SrcVectorDim) of each dim
constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<1>{};
}
else
{
return Number<0>{};
}
},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
// TODO: unable to compile
// position in slice window
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
// position in slice window
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
// src coordinate
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
}
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
vector_type<DstData, 2> scale_vector;
scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
scale_vector.template AsType<DstData>()(Number<1>{}) = scale;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::DequantPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i],
scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 2;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack2{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
});
}
template <typename SrcSliceMoveStepIdx>
__device__ void MoveSrcSliceWindow(const SrcDesc&,
const SrcSliceMoveStepIdx& src_slice_move_step_idx)