mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Jing's contribution: prototype of mixed precision gemm FP16/BF16xint4 GEMM (#1762)
* add a prototype of int4 * clean * debug * clean * clean * move packed into dynamic_buffer * fixed coord reset * add fast pki4 to half conversion * fix * fixed reference and host_tensor * fixed tensor init * format * debug i4_to_f16_convert * format * fixed splitk * weight permute * add b tile permute * clean * weight permute with splitki * format * improve weight layout * add and_or_b32 * fixed splitk crush * add permute switch as a template * recover v3r1 * clean * failure with intrawave v2 * fixed * fixed * add ckProfiler * add bfp16 support * add bf16 example * fixed int4 to bhalf_t conversion * format * fixed int4 to bf16 conversion * clean * add instances for mem * clean * fixed host tensor size * fixed * debug * fixed * add pk_i4_t as a struct * fix * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * revert * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed comments * revert * clean * revert * revert * fixed * Update CMakeLists.txt * Update script/cmake-ck-dev.sh Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update CMakeLists.txt Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed * fixed * fixed * revert * revert * add comments * format * fixed assert * fixed * Fix I4 define in ckProfiler * Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -266,18 +266,18 @@ struct Tensor
|
||||
using Data = std::vector<T>;
|
||||
|
||||
template <typename X>
|
||||
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
|
||||
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
|
||||
: mDesc(lens, strides), mData(GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Lengths>
|
||||
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
|
||||
Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
@@ -287,7 +287,7 @@ struct Tensor
|
||||
{
|
||||
}
|
||||
|
||||
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
|
||||
Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
|
||||
|
||||
template <typename OutT>
|
||||
Tensor<OutT> CopyAsType() const
|
||||
@@ -322,7 +322,17 @@ struct Tensor
|
||||
|
||||
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
|
||||
|
||||
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
|
||||
std::size_t GetElementSpaceSize() const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
{
|
||||
return (mDesc.GetElementSpaceSize() + 1) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mDesc.GetElementSpaceSize();
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
|
||||
|
||||
@@ -469,29 +479,64 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
}
|
||||
|
||||
typename Data::iterator begin() { return mData.begin(); }
|
||||
|
||||
@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::pk_i4_t>
|
||||
{
|
||||
int8_t value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::pk_i4_t operator()(Is...)
|
||||
{
|
||||
int t = value + 8;
|
||||
ck::pk_i4_t r = ((t << 4) + t) & 0xff;
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_2
|
||||
{
|
||||
@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::pk_i4_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::pk_i4_t operator()(Is...)
|
||||
{
|
||||
int hi = std::rand() % (max_value - min_value) + min_value + 8;
|
||||
int lo = std::rand() % (max_value - min_value) + min_value + 8;
|
||||
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f8_t>
|
||||
|
||||
Reference in New Issue
Block a user