Add example of Gemm + AddAddFastGelu (data type: int4) (#369)

* Add custom target to bundle examples together

* Add int4 example conditionally (just copy from int8 example)

* Extract common code into common.hpp

* Move ref gemm type alias into data-type-specific sources

* Add #error directive to prevent compiling with the wrong setting

* Let AddAddFastGelu support int4 parameter type

* Let check_err() support int4 parameter type

* Add wrapper function to hide value conversion while copying memory

* Finish int4 example for GEMM + AddAddFastGelu

* Add new DeviceMem API to copy memory

* Use new DeviceMem API to implement examples

* Fix wrong use of macro 'CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4'

* Revert "Add new DeviceMem API to copy memory"

This reverts commit e26e7af71e.

* Add conversion ctor for Tensor<>

* Add 'const' specifier to Tensor<>::CopyAsType()

* Convert Tensor<> values before/after transfer between host & device

[ROCm/composable_kernel commit: 2327f1a640]
This commit is contained in:
Po Yen Chen
2022-08-23 23:38:41 +08:00
committed by GitHub
parent 1fbd80b0a0
commit 4e53b7beea
11 changed files with 267 additions and 192 deletions

View File

@@ -150,7 +150,12 @@ check_err(const std::vector<T>& out,
}
template <typename T>
typename std::enable_if<std::is_integral<T>::value && !std::is_same<T, bhalf_t>::value, bool>::type
std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, int4_t>
#endif
,
bool>
check_err(const std::vector<T>& out,
const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!",

View File

@@ -254,7 +254,7 @@ struct Tensor
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType()
Tensor<OutT> CopyAsType() const
{
Tensor<OutT> ret(mDesc);
for(size_t i = 0; i < mData.size(); i++)
@@ -264,13 +264,18 @@ struct Tensor
return ret;
}
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
Tensor() = delete;
Tensor(const Tensor&) = default;
Tensor(Tensor&&) = default;
Tensor& operator=(const Tensor& other)
~Tensor() = default;
Tensor& operator=(const Tensor&) = default;
Tensor& operator=(Tensor&&) = default;
template <typename FromT>
explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
{
mDesc = other.mDesc;
mData = other.mData;
return *this;
}
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }