mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK TILE] GEMM with packed i4 (#1885)
* [CK TILE] GEMM with packed i4 * Fixes * fixes * fixes * fixes
This commit is contained in:
@@ -281,18 +281,18 @@ struct HostTensor
|
||||
using Data = std::vector<T>;
|
||||
|
||||
template <typename X>
|
||||
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
|
||||
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.get_element_space_size())
|
||||
: mDesc(lens, strides), mData(get_element_space_size())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Lengths>
|
||||
HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
|
||||
HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
|
||||
{
|
||||
}
|
||||
|
||||
@@ -302,7 +302,7 @@ struct HostTensor
|
||||
{
|
||||
}
|
||||
|
||||
HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
|
||||
HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}
|
||||
|
||||
template <typename OutT>
|
||||
HostTensor<OutT> CopyAsType() const
|
||||
@@ -340,7 +340,11 @@ struct HostTensor
|
||||
|
||||
std::size_t get_element_size() const { return mDesc.get_element_size(); }
|
||||
|
||||
std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
|
||||
std::size_t get_element_space_size() const
|
||||
{
|
||||
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
return mDesc.get_element_space_size() / PackedSize;
|
||||
}
|
||||
|
||||
std::size_t get_element_space_size_in_bytes() const
|
||||
{
|
||||
@@ -463,29 +467,27 @@ struct HostTensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
|
||||
|
||||
Reference in New Issue
Block a user