mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add example of Gemm + AddAddFastGelu (data type: int4) (#369)
* Add custom target to bundle examples together * Add int4 example conditionally (just copy from int8 example) * Extract common code into common.hpp * Move ref gemm type alias into data-type-specific sources * Add #error directive to prevent compile with wrong setting * Let AddAddFastGelu support int4 parameter type * Let check_err() support int4 parameter type * Add wrapper function to hide value conversion while copying memory * Finish int4 example for GEMM + AddAddFastGelu * Add new DeviceMem API to copy memory * Use new DeviceMem API to implement examples * Fix wrongly use of macro 'CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4' * Revert "Add new DeviceMem API to copy memory" This reverts commite26e7af71e. * Add conversion ctor for Tensor<> * Add 'const' specifier to Tensor<>::CopyAsType() * Convert Tensor<> values before/after transfer between host & device [ROCm/composable_kernel commit:2327f1a640]
This commit is contained in:
@@ -150,7 +150,12 @@ check_err(const std::vector<T>& out,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value && !std::is_same<T, bhalf_t>::value, bool>::type
|
||||
std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| std::is_same_v<T, int4_t>
|
||||
#endif
|
||||
,
|
||||
bool>
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
|
||||
@@ -254,7 +254,7 @@ struct Tensor
|
||||
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
|
||||
|
||||
template <typename OutT>
|
||||
Tensor<OutT> CopyAsType()
|
||||
Tensor<OutT> CopyAsType() const
|
||||
{
|
||||
Tensor<OutT> ret(mDesc);
|
||||
for(size_t i = 0; i < mData.size(); i++)
|
||||
@@ -264,13 +264,18 @@ struct Tensor
|
||||
return ret;
|
||||
}
|
||||
|
||||
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
|
||||
Tensor() = delete;
|
||||
Tensor(const Tensor&) = default;
|
||||
Tensor(Tensor&&) = default;
|
||||
|
||||
Tensor& operator=(const Tensor& other)
|
||||
~Tensor() = default;
|
||||
|
||||
Tensor& operator=(const Tensor&) = default;
|
||||
Tensor& operator=(Tensor&&) = default;
|
||||
|
||||
template <typename FromT>
|
||||
explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
|
||||
{
|
||||
mDesc = other.mDesc;
|
||||
mData = other.mData;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
|
||||
|
||||
Reference in New Issue
Block a user