[CK TILE] GEMM with packed i4 (#1885)

* [CK TILE] GEMM with packed i4

* Fixes

* fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-02-20 09:59:49 +01:00
committed by GitHub
parent 824e2c1737
commit 4d9973ec8e
32 changed files with 882 additions and 305 deletions

View File

@@ -281,18 +281,18 @@ struct HostTensor
using Data = std::vector<T>;
template <typename X>
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
{
}
template <typename X, typename Y>
HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.get_element_space_size())
: mDesc(lens, strides), mData(get_element_space_size())
{
}
template <typename Lengths>
HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
{
}
@@ -302,7 +302,7 @@ struct HostTensor
{
}
HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}
template <typename OutT>
HostTensor<OutT> CopyAsType() const
@@ -340,7 +340,11 @@ struct HostTensor
std::size_t get_element_size() const { return mDesc.get_element_size(); }
std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
std::size_t get_element_space_size() const
{
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
return mDesc.get_element_space_size() / PackedSize;
}
std::size_t get_element_space_size_in_bytes() const
{
@@ -463,29 +467,27 @@ struct HostTensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
}
template <typename... Is>
T& operator()(Is... is)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
return mData[GetOffsetFromMultiIndex(is...)];
}
template <typename... Is>
const T& operator()(Is... is) const
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
return mData[GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
return mData[GetOffsetFromMultiIndex(idx)];
}
HostTensor<T> transpose(std::vector<size_t> axes = {}) const