mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
implicit gemm v1r2: only load 1d filter
This commit is contained in:
@@ -8,6 +8,13 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
|
||||
return Sequence<L1, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 3d
|
||||
template <index_t L0, index_t L1, index_t L2>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
|
||||
{
|
||||
return Sequence<L1 * L2, L2, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
@@ -79,6 +86,15 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
|
||||
return Sequence<L1_align, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 3d
|
||||
template <index_t L0, index_t L1, index_t L2, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2>,
|
||||
Number<Align>)
|
||||
{
|
||||
constexpr index_t L2_align = Align * ((L2 + Align - 1) / Align);
|
||||
return Sequence<L1 * L2_align, L2_align, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
|
||||
@@ -244,6 +260,22 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1));
|
||||
}
|
||||
else if(ndim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2));
|
||||
}
|
||||
else if(ndim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
Reference in New Issue
Block a user