Add blockwise gemm to ck wrapper (#1139)

* Add blockwise gemm to ck wrapper

* Add blockwise gemm traits

* Disable test_gemm for non xdl devices

* Fixes

* Add c layout descritpions
This commit is contained in:
Bartłomiej Kocot
2024-01-31 21:24:40 +01:00
committed by GitHub
parent 6651a124cc
commit f3b6c23ac5
12 changed files with 1064 additions and 116 deletions

View File

@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
@@ -19,9 +20,9 @@ namespace wrapper {
* \brief Memory type, allowed members:
* - Generic,
* - Global,
* - LDS,
* - SGPR,
* - VGPR,
* - Lds,
* - Sgpr,
* - Vgpr,
*/
using MemoryTypeEnum = AddressSpaceEnum;
@@ -52,12 +53,8 @@ struct Slice
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
is_same_v<std::remove_const_t<T>, index_t>)
{
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
{
throw std::runtime_error("Invalid range");
}
if(to_ < 0)
{
return dim - from_ + to_ + 1;
@@ -70,9 +67,10 @@ struct Slice
}
else
{
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
(ToType{} < 0 || ToType{} > FromType{}),
"Invalid range");
if constexpr(to_ < 0)
if constexpr(ToType{} < 0)
{
return dim - from_ + to_ + Number<1>{};
}
@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
/**
* \brief Clear tensor. (Only for Vpgr/Sgpr)
*
* \param tensor Tensor to be cleared.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
__host__ __device__ void
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
static_assert(
!Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
return tensor.GetBuffer().Clear();
}
/**
* \brief Get Tensor Layout.
*