mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
try using more constexpr
This commit is contained in:
@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r3, Pascal
|
||||
// for 3x3, 28x28, v1r3, Pascal
|
||||
// for 3x3, 14x14, v1r3, Pascal
|
||||
|
||||
@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
|
||||
@@ -18,11 +18,24 @@ struct Array
|
||||
|
||||
__host__ __device__ constexpr index_t GetSize() const { return NSize; }
|
||||
|
||||
__host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }
|
||||
__host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
|
||||
|
||||
__host__ __device__ TData& operator[](index_t i) { return mData[i]; }
|
||||
|
||||
__host__ __device__ auto PushBack(TData x) const
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData Get(Number<I>) const
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr bool Set(Number<I>, TData x)
|
||||
{
|
||||
mData[I] = x;
|
||||
return true; // for constexpr
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto PushBack(TData x) const
|
||||
{
|
||||
Array<TData, NSize + 1> new_array;
|
||||
|
||||
|
||||
@@ -74,7 +74,8 @@ struct ConstantMergedTensorDescriptor
|
||||
return OriginalTensorDesc::GetElementSize();
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
Array<index_t, nOriginalDim> original_multi_id;
|
||||
@@ -98,21 +99,111 @@ struct ConstantMergedTensorDescriptor
|
||||
|
||||
return original_multi_id;
|
||||
}
|
||||
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
#else
|
||||
template <class OriginalDimsPartial>
|
||||
struct GetOriginalMultiIndexFromMultiIndex_impl1
|
||||
{
|
||||
const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_ref;
|
||||
Array<index_t, nOriginalDim>& original_multi_id_ref;
|
||||
|
||||
__host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl1(
|
||||
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial,
|
||||
Array<index_t, nOriginalDim>& original_multi_id)
|
||||
: original_multi_id_partial_ref(original_multi_id_partial),
|
||||
original_multi_id_ref(original_multi_id)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
constexpr __host__ __device__ bool operator()(Number<I>) const
|
||||
{
|
||||
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
|
||||
|
||||
index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});
|
||||
|
||||
original_multi_id_ref.Set(Number<idim_original>{}, itmp);
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Functor meant to be driven by static_for over the merged dimensions.
// For merged dimension IDim it expands the 1d index stored in
// multi_id_ref[IDim] into the indices of the original dimensions that were
// merged into IDim, then scatters them into original_multi_id_ref.
struct GetOriginalMultiIndexFromMultiIndex_impl0
{
    // multi-index of the merged descriptor (read only)
    const Array<index_t, nDim>& multi_id_ref;
    // multi-index of the original descriptor (written through the reference)
    Array<index_t, nOriginalDim>& original_multi_id_ref;

    __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl0(
        const Array<index_t, nDim>& multi_id, Array<index_t, nOriginalDim>& original_multi_id)
        : multi_id_ref(multi_id), original_multi_id_ref(original_multi_id)
    {
    }

    // Always returns true; the bool return type exists so the call can be
    // used in a constexpr context.
    template <index_t IDim>
    constexpr __host__ __device__ bool operator()(Number<IDim>) const
    {
        // the IDim-th entry of OriginalDimMergeSeqs...: the original
        // dimensions that make up merged dimension IDim
        constexpr auto merged_dims = std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{});

        using MergedDims = decltype(merged_dims);

        // partial original multi-index belonging to this merged dimension
        const auto partial_id =
            OriginalTensorDesc::Extract(merged_dims).GetMultiIndexFrom1dIndex(multi_id_ref[IDim]);

        // copy every entry of the partial index to its original dimension
        static_for<0, MergedDims::GetSize(), 1>{}(
            GetOriginalMultiIndexFromMultiIndex_impl1<MergedDims>(partial_id,
                                                                  original_multi_id_ref));

        return true;
    }
};
|
||||
|
||||
// Convert a multi-index of this merged descriptor into the multi-index of the
// original tensor descriptor.
//   multi_id: one index per merged dimension (nDim entries)
//   returns : Array<index_t, nOriginalDim>, filled in by running
//             GetOriginalMultiIndexFromMultiIndex_impl0 once per merged
//             dimension
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
    Array<index_t, nOriginalDim> result;

    // the functor only stores references, so it writes straight into `result`
    const GetOriginalMultiIndexFromMultiIndex_impl0 expand_dim(multi_id, result);

    static_for<0, nDim, 1>{}(expand_dim);

    return result;
}
|
||||
|
||||
// Compile-time overload: the merged multi-index is carried by the type
// Sequence<Is...>. Both intermediate arrays are declared constexpr, so the
// index conversion has to be evaluable at compile time.
//   returns: offset reported by OriginalTensorDesc for the converted index
template <index_t... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
{
    // Sequence -> Array, so the Array-based conversion can be reused
    constexpr auto merged_id = sequence2array(Sequence<Is...>{});

    constexpr auto original_id = GetOriginalMultiIndexFromMultiIndex(merged_id);

    return OriginalTensorDesc::GetOffsetFromMultiIndex(original_id);
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// return type is Sequence<...>
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence<Is...>)
|
||||
{
|
||||
// not implemented
|
||||
return Sequence<>{};
|
||||
}
|
||||
#endif
|
||||
|
||||
// Offset of the element addressed by a multi-index of this merged descriptor.
// The index is first converted to a multi-index of the original descriptor,
// which then computes the offset.
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{
    const auto original_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

    return OriginalTensorDesc::GetOffsetFromMultiIndex(original_id);
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths());
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), mod_conv::multiplies<index_t>{})
|
||||
return reverse_inclusive_scan_sequence(
|
||||
Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
}
|
||||
|
||||
@@ -91,8 +92,10 @@ struct ConstantTensorDescriptor
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
@@ -105,9 +108,43 @@ struct ConstantTensorDescriptor
|
||||
|
||||
return offset;
|
||||
}
|
||||
#else
|
||||
// Functor meant to be driven by static_for over the dimensions of the
// descriptor. Each call adds
//     multi_id[IDim] * Type::GetStride(Number<IDim>{})
// to the offset it was constructed with.
//
// The multi-index is only ever read, so it is held (and accepted) by const
// reference; callers passing a non-const Array keep working unchanged.
template <index_t NSize>
struct GetOffsetFromMultiIndex_impl
{
    // multi-index to read from (never modified)
    const Array<index_t, NSize>& multi_id_ref;
    // running offset owned by the caller; updated in place
    index_t& offset_ref;

    __host__ __device__ constexpr GetOffsetFromMultiIndex_impl(
        const Array<index_t, NSize>& multi_id, index_t& offset)
        : multi_id_ref(multi_id), offset_ref(offset)
    {
    }

    // const-qualified, yet it writes through offset_ref: the reference member
    // itself is unchanged, only the object it refers to is updated.
    // Always returns true so the call can be used in a constexpr context.
    template <index_t IDim>
    __host__ __device__ constexpr bool operator()(Number<IDim>) const
    {
        offset_ref += multi_id_ref.Get(Number<IDim>{}) * Type::GetStride(Number<IDim>{});
        return true;
    }
};
|
||||
|
||||
// Offset of the element addressed by multi_id: the sum over all dimensions of
// index * stride, accumulated by GetOffsetFromMultiIndex_impl.
//   multi_id: one index per dimension; NSize must equal nDim (checked at
//             compile time)
template <index_t NSize>
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{
    static_assert(NSize == nDim, "wrong! Dimension not consistent");

    index_t linear_offset = 0;

    // the functor stores a reference to linear_offset and adds to it once per
    // dimension
    const GetOffsetFromMultiIndex_impl<NSize> accumulate(multi_id, linear_offset);

    static_for<0, nDim, 1>{}(accumulate);

    return linear_offset;
}
|
||||
#endif
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
|
||||
}
|
||||
@@ -123,7 +160,8 @@ struct ConstantTensorDescriptor
|
||||
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
#if 0
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
@@ -141,8 +179,58 @@ struct ConstantTensorDescriptor
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
#else
|
||||
struct GetMultiIndexFrom1dIndex_impl
|
||||
{
|
||||
using DummyStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
|
||||
|
||||
__host__ __device__ static auto
|
||||
index_t& id_ref;
|
||||
Array<index_t, nDim>& multi_id_ref;
|
||||
|
||||
__host__ __device__ constexpr GetMultiIndexFrom1dIndex_impl(index_t& id,
|
||||
Array<index_t, nDim>& multi_id)
|
||||
: id_ref(id), multi_id_ref(multi_id)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr bool operator()(Number<IDim>) const
|
||||
{
|
||||
constexpr index_t stride = DummyStrides::Get(Number<IDim>{});
|
||||
multi_id_ref.Set(Number<IDim>{}, id_ref / stride);
|
||||
id_ref -= multi_id_ref.Get(Number<IDim>{}) * stride;
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Decompose a 1d index into a multi-index, using the strides that
// calculate_tensor_strides_packed derives from GetLengths() (not the
// descriptor's own strides).
//   id     : 1d index to decompose
//   returns: Array<index_t, nDim> holding one index per dimension
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
    constexpr index_t last_dim = nDim - 1;

    constexpr auto packed_strides = calculate_tensor_strides_packed(GetLengths());

    Array<index_t, nDim> multi_id;

    // dimensions 0 .. nDim-2, in order: the functor stores id / stride for the
    // dimension and subtracts that contribution from `id` through its reference
    static_for<0, last_dim, 1>{}(GetMultiIndexFrom1dIndex_impl(id, multi_id));

    // last dimension: whatever is left of `id`, divided by the last stride
    multi_id.Set(Number<last_dim>{}, id / packed_strides.Get(Number<last_dim>{}));

    return multi_id;
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// return type is Sequence<...>
|
||||
template<index_t Id>
|
||||
__host__ __device__ static constexpr auto GetMultiIndexFrom1dIndex(Number<Id>)
|
||||
{
|
||||
return inclusive_scan_sequence(f_impl, GetStrides(), Number<Id>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
return multi_id;
|
||||
@@ -278,8 +366,8 @@ struct ConstantTensorDescriptor
|
||||
// folded strides
|
||||
constexpr auto fold_strides =
|
||||
Number<unfold_stride>{} *
|
||||
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
|
||||
mod_conv::multiplies<index_t>{});
|
||||
reverse_inclusive_scan_sequence(
|
||||
fold_intervals.PushBack(Number<1>{}), mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
|
||||
|
||||
@@ -139,31 +139,49 @@ struct arithmetic_sequence_gen
|
||||
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
|
||||
};
|
||||
|
||||
template <class, class>
|
||||
// reverse scan with init
|
||||
template <class, class, index_t>
|
||||
struct sequence_reverse_inclusive_scan;
|
||||
|
||||
template <index_t I, index_t... Is, class Reduce>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
|
||||
template <index_t I, index_t... Is, class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
|
||||
{
|
||||
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;
|
||||
using old_scan =
|
||||
typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::SeqType;
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
|
||||
|
||||
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
|
||||
};
|
||||
|
||||
template <index_t I, class Reduce>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
|
||||
template <index_t I, class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
|
||||
{
|
||||
using SeqType = Sequence<I>;
|
||||
using SeqType = Sequence<Reduce{}(I, Init)>;
|
||||
};
|
||||
|
||||
template <class Reduce>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce>
|
||||
template <class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
|
||||
{
|
||||
using SeqType = Sequence<>;
|
||||
};
|
||||
|
||||
#if 0
|
||||
// reverse scan with token
|
||||
template <class, class, index_t>
|
||||
struct sequence_reverse_inclusive_token_scan;
|
||||
|
||||
template <index_t I, index_t... Is, class F, index_t Token>
|
||||
struct sequence_reverse_inclusive_token_scan<Sequence<I, Is...>, F, Token>
|
||||
{
|
||||
using old_scan = typename sequence_reverse_inclusive_token_scan<Sequence<Is...>, F, Token>::SeqType;
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
|
||||
|
||||
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <class, class>
|
||||
struct sequence_extract;
|
||||
|
||||
@@ -434,16 +452,16 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
return Sequence<f(Xs, Ys, Zs)...>{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{};
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::SeqType{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
|
||||
}
|
||||
|
||||
template <class Seq>
|
||||
|
||||
@@ -203,6 +203,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if 0
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
const auto src_thread_data_multi_id_begin =
|
||||
@@ -216,6 +217,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
|
||||
clipboard_data_multi_id_begin); // cannot not constexpr, why?
|
||||
#else
|
||||
constexpr auto src_thread_data_multi_id_begin =
|
||||
repeat_multi_id_ * data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id_ * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
struct forwarder
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T operator()(T&& x) const
|
||||
__host__ __device__ constexpr T&& operator()(T&& x) const
|
||||
{
|
||||
return std::forward<T>(x);
|
||||
return static_cast<T&&>(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -76,7 +76,7 @@ template <index_t Iter, index_t Remaining, index_t Increment>
|
||||
struct static_for_impl
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
constexpr __host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
|
||||
static_assert(Increment <= Remaining, "will go out-of-range");
|
||||
@@ -90,7 +90,7 @@ template <index_t Iter, index_t Increment>
|
||||
struct static_for_impl<Iter, 0, Increment>
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F) const
|
||||
constexpr __host__ __device__ void operator()(F) const
|
||||
{
|
||||
// no work left, just return
|
||||
return;
|
||||
@@ -102,13 +102,19 @@ template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
constexpr __host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
|
||||
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
|
||||
#if 0
|
||||
static_if<(NBegin < NEnd)>{}(
|
||||
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
|
||||
#else
|
||||
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>{};
|
||||
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -235,8 +235,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
|
||||
@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
#if 0
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run_asm(Xs...);
|
||||
|
||||
Reference in New Issue
Block a user