mirror of https://github.com/ROCm/composable_kernel.git

rework sequence
@@ -1,30 +1,11 @@
 #pragma once
 #include "common.hip.hpp"
 
-template <class PreviousStrides, class RemainLengths>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_length  = RemainLengths{}.Back();
-    constexpr index_t current_stride  = current_length * previous_stride;
-
-    return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
-                                          RemainLengths{}.PopBack());
-}
-
-template <class PreviousStrides, index_t L0, index_t L1>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_stride  = L1 * previous_stride;
-
-    return PreviousStrides{}.PushFront(Number<current_stride>{});
-}
-
 template <class Lengths>
 __host__ __device__ constexpr auto calculate_default_strides(Lengths)
 {
-    return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
+    return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}),
+                                           std::multiplies<index_t>{});
 }
 
 // this is ugly, only for 2d
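
A quick compile-time check of what the reworked calculate_default_strides produces (a hypothetical sketch, not part of the commit, assuming the headers above are included):

// For lengths {2, 3, 4}: drop the first length, append 1 -> {3, 4, 1},
// then a reverse inclusive scan with multiplies yields the row-major
// strides {12, 4, 1}, i.e. stride[i] = product of lengths[i+1 .. n-1].
constexpr auto strides = calculate_default_strides(Sequence<2, 3, 4>{});
static_assert(strides.Get(Number<0>{}) == 12, "");
static_assert(strides.Get(Number<1>{}) == 4, "");
static_assert(strides.Get(Number<2>{}) == 1, "");
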
@@ -57,7 +38,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
 template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
+    using Type = ConstantTensorDescriptor;
 
     static constexpr index_t nDim = Lengths::GetSize();
 
     __host__ __device__ constexpr ConstantTensorDescriptor()
@@ -193,7 +175,8 @@ struct ConstantTensorDescriptor
         // folded strides
         constexpr auto fold_strides =
             Number<unfold_stride>{} *
-            reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
+            reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
+                                            std::multiplies<index_t>{});
 
         // left and right
         constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
@@ -9,7 +9,8 @@ struct Sequence
     static constexpr index_t mSize = sizeof...(Is);
 
-    const index_t mData[mSize] = {Is...};
+    const index_t mData[mSize + 1] = {
+        Is..., 0}; // the last element is a dummy, to keep the compiler from complaining about an empty Sequence
 
     __host__ __device__ static constexpr index_t GetSize() { return mSize; }
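
Why the extra element: a zero-length array member is ill-formed, so Sequence<> could not otherwise be instantiated. A hypothetical sketch, not part of the commit:

// With mData[mSize], Sequence<> would declare `const index_t mData[0]`;
// padding with one dummy element keeps the empty specialization valid
// while GetSize() still reports the logical size.
constexpr auto empty_seq = Sequence<>{};
static_assert(empty_seq.GetSize() == 0, "size excludes the dummy element");
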
@@ -39,10 +40,7 @@ struct Sequence
         assert(false);
     }
 
-    __host__ __device__ constexpr auto Reverse() const
-    {
-        // not implemented
-    }
+    __host__ __device__ constexpr auto Reverse() const;
 
     __host__ __device__ constexpr index_t Front() const { return mData[0]; }
@@ -73,13 +71,13 @@ struct Sequence
     template <index_t... Ns>
     __host__ __device__ constexpr auto Extract(Number<Ns>...) const
     {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
     }
 
     template <index_t... Ns>
     __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
     {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
     }
 };
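
Extract picks an arbitrary subset of elements by position; spelling the call as Type{}.Get(...) makes it a constant expression independent of `this`. A hypothetical usage sketch, not part of the commit:

constexpr auto seq = Sequence<10, 20, 30, 40>{};
constexpr auto sub = seq.Extract(Number<0>{}, Number<2>{}); // -> Sequence<10, 30>
static_assert(sub.GetSize() == 2 && sub.Get(Number<1>{}) == 30, "");
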
@@ -89,44 +87,110 @@ struct sequence_merge;
 template <index_t... Xs, index_t... Ys>
 struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
 {
-    using Type = Sequence<Xs..., Ys...>;
+    using SeqType = Sequence<Xs..., Ys...>;
 };
 
 template <index_t IBegin, index_t NSize, index_t Increment>
-struct increasing_sequence_gen
+struct increasing_sequence_gen_impl
 {
     static constexpr index_t NSizeLeft = NSize / 2;
 
-    using Type =
-        sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
-                       typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
-                                                        NSize - NSizeLeft,
-                                                        Increment>::Type>;
+    using SeqType = typename sequence_merge<
+        typename increasing_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
+        typename increasing_sequence_gen_impl<IBegin + NSizeLeft * Increment,
+                                              NSize - NSizeLeft,
+                                              Increment>::SeqType>::SeqType;
 };
 
 template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 1, Increment>
+struct increasing_sequence_gen_impl<IBegin, 1, Increment>
 {
-    using Type = Sequence<IBegin>;
+    using SeqType = Sequence<IBegin>;
 };
 
 template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 0, Increment>
+struct increasing_sequence_gen_impl<IBegin, 0, Increment>
 {
-    using Type = Sequence<>;
+    using SeqType = Sequence<>;
 };
 
+template <index_t IBegin, index_t IEnd, index_t Increment>
+struct increasing_sequence_gen
+{
+    using SeqType =
+        typename increasing_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
+};
+
 template <index_t IBegin, index_t IEnd, index_t Increment>
 __host__ __device__ constexpr auto
 make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
 {
     static_assert(IBegin <= IEnd && Increment > 0, "wrong!");
 
-    constexpr index_t NSize = (IEnd - IBegin) / Increment;
-
-    return increasing_sequence_gen<IBegin, NSize, Increment>{};
+    return typename increasing_sequence_gen<IBegin, IEnd, Increment>::SeqType{};
 }
 
+template <class, class>
+struct sequence_reverse_inclusive_scan;
+
+template <index_t I, index_t... Is, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
+{
+    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;
+
+    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
+
+    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
+};
+
+template <index_t I, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
+{
+    using SeqType = Sequence<I>;
+};
+
+template <class, class>
+struct sequence_extract;
+
+template <class Seq, index_t... Is>
+struct sequence_extract<Seq, Sequence<Is...>>
+{
+    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
+};
+
+template <class Seq, index_t I>
+struct sequence_split
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType;
+    using range1 = typename increasing_sequence_gen<I, NSize, 1>::SeqType;
+
+    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
+    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
+};
+
+template <class Seq>
+struct sequence_reverse
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using seq_split = sequence_split<Seq, NSize / 2>;
+    using SeqType = typename sequence_merge<
+        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
+        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
+};
+
+template <index_t I>
+struct sequence_reverse<Sequence<I>>
+{
+    using SeqType = Sequence<I>;
+};
+
+template <index_t I0, index_t I1>
+struct sequence_reverse<Sequence<I0, I1>>
+{
+    using SeqType = Sequence<I1, I0>;
+};
+
 template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
 {
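
The divide-and-conquer generators keep template recursion depth at O(log N) rather than O(N). A hypothetical compile-time check, not part of the commit (assuming <type_traits> is available):

constexpr auto inc = make_increasing_sequence(Number<0>{}, Number<4>{}, Number<1>{});
static_assert(std::is_same<decltype(inc), const Sequence<0, 1, 2, 3>>::value, "");

using Rev = typename sequence_reverse<Sequence<0, 1, 2, 3>>::SeqType;
static_assert(std::is_same<Rev, Sequence<3, 2, 1, 0>>::value, "");
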
@@ -179,9 +243,9 @@ __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
 template <index_t... Xs, index_t Y>
 __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
 {
-#if 0 // doesn't compile
     constexpr auto seq_x = Sequence<Xs...>{};
 
+#if 0 // doesn't compile
     static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
         constexpr auto I = decltype(Iter){};
         static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
@@ -253,95 +317,12 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
     return Sequence<Is...>{};
 }
 
-#if 0
-// TODO: for some reason, compiler cannot instantiate this template
-template <index_t... Is, index_t I>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
+template <class Seq>
+__host__ __device__ constexpr auto sequence_pop_back(Seq)
 {
-    static_assert(sizeof...(Is) > 0, "empty Sequence!");
-    return Sequence<Is...>{};
+    static_assert(Seq{}.GetSize() > 0, "empty Sequence!");
+    return sequence_pop_front(Seq{}.Reverse()).Reverse();
 }
-#else
-// TODO: delete these very ugly mess
-template <index_t I0, index_t I1>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
-{
-    return Sequence<I0>{};
-}
-
-template <index_t I0, index_t I1, index_t I2>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
-{
-    return Sequence<I0, I1>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
-{
-    return Sequence<I0, I1, I2>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
-{
-    return Sequence<I0, I1, I2, I3>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
-{
-    return Sequence<I0, I1, I2, I3, I4>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5>{};
-}
-
-template <index_t I0,
-          index_t I1,
-          index_t I2,
-          index_t I3,
-          index_t I4,
-          index_t I5,
-          index_t I6,
-          index_t I7>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
-}
-
-template <index_t I0,
-          index_t I1,
-          index_t I2,
-          index_t I3,
-          index_t I4,
-          index_t I5,
-          index_t I6,
-          index_t I7,
-          index_t I8>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
-}
-
-template <index_t I0,
-          index_t I1,
-          index_t I2,
-          index_t I3,
-          index_t I4,
-          index_t I5,
-          index_t I6,
-          index_t I7,
-          index_t I8,
-          index_t I9>
-__host__ __device__ constexpr auto
-sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
-}
-#endif
 
 template <class F, index_t... Xs>
 __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
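
With Reverse() available, popping the back element reduces to reversing, popping the front, and reversing again, which is what lets the per-arity overloads go away. A hypothetical check, not part of the commit:

constexpr auto popped = sequence_pop_back(Sequence<1, 2, 3>{});
static_assert(std::is_same<decltype(popped), const Sequence<1, 2>>::value,
              "Reverse -> pop_front -> Reverse drops the last element");
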
@@ -399,38 +380,20 @@ __host__ __device__ constexpr index_t
     return Reduce{}(a, I);
 }
 
-template <index_t NRemain>
-struct scan_sequence_impl
+template <index_t... Is>
+__host__ __device__ constexpr auto Sequence<Is...>::Reverse() const
 {
-    template <class ScanedSeq, class RemainSeq, class Reduce>
-    __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
-    {
-        static_assert(RemainSeq{}.GetSize() == NRemain,
-                      "wrong! RemainSeq and NRemain not consistent!");
-
-        constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());
-        constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number<a>{});
-
-        static_if<(NRemain > 1)>{}([&](auto fwd) {
-            return scan_sequence_impl<NRemain - 1>{}(
-                scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{}));
-        }).else_([&](auto fwd) { return fwd(scaned_seq); });
-    }
-};
-
-template <class Seq, class Reduce>
-__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
-{
-    constexpr auto scaned_seq = Sequence<Seq{}.front()>{};
-    constexpr auto remain_seq = Seq{}.PopFront();
-
-    constexpr index_t remain_size = Seq::GetSize() - 1;
-
-    return scan_sequence_impl<remain_size>{}(scaned_seq, remain_seq, Reduce{});
+    return typename sequence_reverse<Sequence<Is...>>::SeqType{};
 }
 
 template <class Seq, class Reduce>
-__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
+__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
 {
-    return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse();
+    return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{};
 }
+
+template <class Seq, class Reduce>
+__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
+{
+    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
+}
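
The scans are now pure type computations instead of static_if recursion. A hypothetical check of both directions with multiplies, not part of the commit:

// reverse inclusive scan: out[i] = seq[i] * seq[i+1] * ... * seq[n-1]
constexpr auto r = reverse_inclusive_scan_sequence(Sequence<3, 4, 1>{},
                                                   std::multiplies<index_t>{});
static_assert(std::is_same<decltype(r), const Sequence<12, 4, 1>>::value, "");

// inclusive scan: out[i] = seq[0] * seq[1] * ... * seq[i]
constexpr auto f = inclusive_scan_sequence(Sequence<3, 4, 1>{},
                                           std::multiplies<index_t>{});
static_assert(std::is_same<decltype(f), const Sequence<3, 12, 12>>::value, "");
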
@@ -80,29 +80,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
 
-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
 
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
 
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t k_block_data_begin  = block_work_multi_id[0] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
+        const index_t n_block_data_begin  = block_work_multi_id[3] * NPerBlock;
 
         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
 
         // global tensor view
-        constexpr auto wei_c_k_global_desc =
-            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
+        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
 
         // LDS tensor view
         // be careful of alignment
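
GetMultiIndex replaces the hand-written divide/subtract chain: the linear block id is decomposed against the default strides of a {K, H, W, N}-shaped work descriptor. A hypothetical sketch of the equivalence, not part of the commit:

// For block-work shape {2, 3, 4, 5} (strides {60, 20, 5, 1}) and block id 37,
// the multi-index is {0, 1, 3, 2}: 37/60 = 0 rem 37; 37/20 = 1 rem 17;
// 17/5 = 3 rem 2 -- exactly what the old repeated div/mod code produced.
constexpr auto work_desc = make_ConstantTensorDescriptor(Sequence<2, 3, 4, 5>{});
const auto multi_id = work_desc.GetMultiIndex(37); // -> {0, 1, 3, 2}
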
@@ -360,13 +357,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
 
-        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
-                                                                        // perfect forwarding.
-                                                                        // Using this trick to
-            // make this lambda a generic lambda, so it won't be compiled until
-            // instantiated
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
+            // fwd do nothing but perfect forwarding.
+            // Using this trick to make this lambda a generic lambda, so it won't be compiled until
+            // being instantiated here
             static_assert(
-                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                 "wrong!");
 
             // output is a 10d tensor
@@ -374,38 +370,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             constexpr index_t N1 = NPerBlock / N2;
 
             constexpr index_t W2 =
-                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
             constexpr index_t W1 = WoPerBlock / W2;
 
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;
 
-            constexpr auto out_10d_global_desc =
-                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-                                                       K1,
-                                                       K2,
-                                                       Ho,
-                                                       Wo / (W1 * W2),
-                                                       W1,
-                                                       W2,
-                                                       N / f_dummy(N1 * N2),
-                                                       N1,
-                                                       N2>{});
+            constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
+                                                     .Fold(I3, Number<N1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<W2>{})
+                                                     .Fold(I0, Number<K1>{}, Number<K2>{});
 
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
+                                                     .Fold(I3, Number<1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<1>{})
+                                                     .Fold(I0, Number<1>{}, Number<K2>{});
 
 #if 0
-            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-            {
-                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-
-                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-            }
+            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+            {
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "a: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
+                                               "a: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
+            }
 #endif
 
             threadwise_tensor_slice_copy(
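
Fold splits one dimension of a descriptor into several, replacing the hand-written 10d shape. A hypothetical sketch of the idea on a small case, not part of the commit:

// Folding dimension 0 of a {8, 5}-shaped descriptor by sub-lengths {2, 2}
// yields a {2, 2, 2, 5}-shaped descriptor over the same memory: the leading
// factor 8 / (2 * 2) = 2 is inferred, as in the 10d descriptors above.
constexpr auto desc_2d = make_ConstantTensorDescriptor(Sequence<8, 5>{});
constexpr auto desc_4d = desc_2d.Fold(I0, Number<2>{}, Number<2>{});
static_assert(decltype(desc_4d)::nDim == 4, "one dim folded into three");
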
@@ -419,8 +410,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
-            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+        }).else_([&](auto fwd) {
+            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
 
@@ -429,33 +420,34 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
             constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
 
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;
 
-            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
+            constexpr auto out_10d_global_desc =
+                fwd(out_k_h_w_n_global_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
+                    .Fold(I0, Number<K1>{}, Number<K2>{});
 
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
 
 #if 0
-            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-            {
-                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-
-                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-
-                for(index_t i = 0; i < 64; ++i)
-                {
-                    printf("out %f, ", p_out_thread[i]);
-                }
-            }
+            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+            {
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "b: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
+                                               "b: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
+            }
 #endif
 
             threadwise_tensor_slice_copy(