mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Jing's contribution: prototype of mixed-precision FP16/BF16 x int4 GEMM (#1762)
* add a prototype of int4 * clean * debug * clean * clean * move packed into dynamic_buffer * fixed coord reset * add fast pki4 to half conversion * fix * fixed reference and host_tensor * fixed tensor init * format * debug i4_to_f16_convert * format * fixed splitk * weight permute * add b tile permute * clean * weight permute with splitki * format * improve weight layout * add and_or_b32 * fixed splitk crush * add permute switch as a template * recover v3r1 * clean * failure with intrawave v2 * fixed * fixed * add ckProfiler * add bfp16 support * add bf16 example * fixed int4 to bhalf_t conversion * format * fixed int4 to bf16 conversion * clean * add instances for mem * clean * fixed host tensor size * fixed * debug * fixed * add pk_i4_t as a struct * fix * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * revert * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed comments * revert * clean * revert * revert * fixed * Update CMakeLists.txt * Update script/cmake-ck-dev.sh Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update 
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update CMakeLists.txt Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed * fixed * fixed * revert * revert * add comments * format * fixed assert * fixed * Fix I4 define in ckProfiler * Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
// PackedSize is a compile-time constant: 2 when SrcData (cv/ref stripped) is
// pk_i4_t, otherwise 1. It is computed by an immediately-invoked lambda so the
// `if constexpr` selection can be written as ordinary statements.
// Elsewhere in this struct it divides SrcScalarPerVector when sizing the
// source vector and divides the source coordinate offset passed to
// src_buf.Get.
// NOTE(review): presumably 2 is the number of int4 values held by one
// pk_i4_t object -- confirm against the pk_i4_t definition.
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
|
||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||
{
|
||||
@@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
{
|
||||
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcRefToOriginDisplacement,
|
||||
@@ -1109,7 +1121,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_tmp_vector)::type;
|
||||
|
||||
@@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
if constexpr(SrcBuffer::IsDynamicBuffer())
|
||||
{
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
|
||||
is_src_valid);
|
||||
}
|
||||
else if constexpr(SrcBuffer::IsStaticBuffer())
|
||||
{
|
||||
@@ -1133,9 +1146,36 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
constexpr index_t pack_size = 8;
|
||||
|
||||
static_assert(SrcScalarPerVector % pack_size == 0, "");
|
||||
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack8{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
|
||||
@@ -31,8 +31,8 @@ template <typename SliceLengths,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarPerVector_,
|
||||
index_t DstScalarPerVector_,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
// PackedSize is a compile-time constant: 2 when SrcData (cv/ref stripped) is
// pk_i4_t, otherwise 1. It is computed by an immediately-invoked lambda.
// Within this struct it scales the SrcScalarPerVector_/DstScalarPerVector_
// template arguments and divides the coordinate offsets handed to
// src_buf.Get / dst_buf.Set.
// NOTE(review): only SrcData is inspected here; DstData is not consulted.
// NOTE(review): presumably 2 is the number of int4 values held by one
// pk_i4_t object -- confirm against the pk_i4_t definition.
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
// Vector widths as Number<> constants: the raw template arguments
// (SrcScalarPerVector_ / DstScalarPerVector_) divided by PackedSize.
// When PackedSize is 1 these equal the template arguments.
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
|
||||
// The destination width is divided by the same PackedSize, which is derived
// from SrcData only; for pk_i4_t the constructor static_asserts that SrcData
// and DstData are the same type and that both widths are multiples of
// PackedSize, so the integer division is exact in that case.
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op)
|
||||
{
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
|
||||
"SrcData != DstData");
|
||||
|
||||
static_assert(
|
||||
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
|
||||
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
|
||||
|
||||
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
|
||||
else if constexpr(is_detected<is_pack4_invocable_t,
|
||||
decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
|
||||
else if constexpr(is_detected<is_pack2_invocable_t,
|
||||
decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
#else
|
||||
|
||||
// OOB Check
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
|
||||
"in-register transpose is not supported for pk_i4_t");
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
constexpr auto packed_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
|
||||
|
||||
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
// apply DstElementwiseOperation
|
||||
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
|
||||
|
||||
dst_vector_container.template AsType<DstData>()(i) = dst_v;
|
||||
});
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
dst_coord_.GetOffset() / PackedSize,
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user