mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add blockwise gemm to ck wrapper (#1139)
* Add blockwise gemm to ck wrapper * Add blockwise gemm traits * Disable test_gemm for non xdl devices * Fixes * Add c layout descritpions
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
@@ -19,9 +20,9 @@ namespace wrapper {
|
||||
* \brief Memory type, allowed members:
|
||||
* - Generic,
|
||||
* - Global,
|
||||
* - LDS,
|
||||
* - SGPR,
|
||||
* - VGPR,
|
||||
* - Lds,
|
||||
* - Sgpr,
|
||||
* - Vgpr,
|
||||
*/
|
||||
using MemoryTypeEnum = AddressSpaceEnum;
|
||||
|
||||
@@ -52,12 +53,8 @@ struct Slice
|
||||
__host__ __device__ constexpr auto range(const T& dim) const
|
||||
{
|
||||
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
|
||||
is_same_v<T, index_t>)
|
||||
is_same_v<std::remove_const_t<T>, index_t>)
|
||||
{
|
||||
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
|
||||
{
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
if(to_ < 0)
|
||||
{
|
||||
return dim - from_ + to_ + 1;
|
||||
@@ -70,9 +67,10 @@ struct Slice
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
|
||||
static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
|
||||
(ToType{} < 0 || ToType{} > FromType{}),
|
||||
"Invalid range");
|
||||
if constexpr(to_ < 0)
|
||||
if constexpr(ToType{} < 0)
|
||||
{
|
||||
return dim - from_ + to_ + Number<1>{};
|
||||
}
|
||||
@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
|
||||
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Clear tensor. (Only for Vpgr/Sgpr)
|
||||
*
|
||||
* \param tensor Tensor to be cleared.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename UnrolledDescriptorType>
|
||||
__host__ __device__ void
|
||||
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
|
||||
{
|
||||
static_assert(
|
||||
!Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
|
||||
return tensor.GetBuffer().Clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor Layout.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user