mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
rename
This commit is contained in:
@@ -10,7 +10,7 @@ template <index_t NDimHidden, typename VisibleDimensionIds>
|
||||
struct TensorCoordinate;
|
||||
|
||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||
struct TensorCoordinateIterator;
|
||||
struct TensorCoordinateStep;
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
|
||||
@@ -252,17 +252,16 @@ struct TensorCoordinate
|
||||
};
|
||||
|
||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||
struct TensorCoordinateIterator
|
||||
struct TensorCoordinateStep
|
||||
{
|
||||
// TODO make these private
|
||||
using VisibleIndex = MultiIndex<NDimVisible>;
|
||||
|
||||
public:
|
||||
__host__ __device__ constexpr TensorCoordinateIterator() = default;
|
||||
__host__ __device__ constexpr TensorCoordinateStep() = default;
|
||||
|
||||
__host__
|
||||
__device__ constexpr TensorCoordinateIterator(const VisibleIndex& idx_diff_visible,
|
||||
const MultiIndex<NTransform>& do_transforms)
|
||||
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
|
||||
const MultiIndex<NTransform>& do_transforms)
|
||||
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
|
||||
{
|
||||
}
|
||||
@@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens
|
||||
// UpdateLowerIndexHack: Sequence<...>
|
||||
// HACK: control UpdateLowerIndex
|
||||
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
|
||||
__host__ __device__ constexpr auto make_tensor_coordinate_iterator(
|
||||
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
|
||||
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||
const VisibleIndex& idx_diff_visible,
|
||||
UpdateLowerIndexHack)
|
||||
{
|
||||
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
@@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator(
|
||||
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
||||
});
|
||||
|
||||
return TensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
|
||||
idx_diff_visible, do_transforms};
|
||||
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
|
||||
do_transforms};
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename VisibleIndex>
|
||||
__host__ __device__ constexpr auto
|
||||
make_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
|
||||
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||
const VisibleIndex& idx_diff_visible)
|
||||
{
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
return make_tensor_coordinate_step(
|
||||
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
|
||||
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
|
||||
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||
TensorCoord& coord,
|
||||
const TensorCoordIterator& coord_iterator)
|
||||
const TensorCoordStep& coord_step)
|
||||
{
|
||||
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
@@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
|
||||
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// initialize visible index diff
|
||||
set_container_subset(idx_diff_hidden,
|
||||
TensorDesc::GetVisibleDimensionIds(),
|
||||
coord_iterator.GetVisibleIndexDiff());
|
||||
set_container_subset(
|
||||
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
|
||||
|
||||
// this is what needs to be updated
|
||||
auto& idx_hidden = coord.GetHiddenIndex();
|
||||
@@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
|
||||
auto idx_hidden_pick_visible =
|
||||
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
|
||||
|
||||
idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
|
||||
idx_hidden_pick_visible += coord_step.GetIndexDiff();
|
||||
|
||||
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
|
||||
|
||||
// update rest of hidden index
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
if(coord_iterator.do_transforms_[itran])
|
||||
if(coord_step.do_transforms_[itran])
|
||||
{
|
||||
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
||||
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
||||
@@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
|
||||
MultiIndex<dims_low.Size()> idx_diff_low;
|
||||
|
||||
// HACK: control UpdateLowerIndex for Merge using hack
|
||||
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
|
||||
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
|
||||
|
||||
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
|
||||
|
||||
@@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate(
|
||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||
|
||||
template <typename TensorDesc>
|
||||
using TensorCoordinateIterator_t = decltype(make_tensor_coordinate_iterator(
|
||||
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
|
||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4
|
||||
}
|
||||
}
|
||||
|
||||
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowIteratorHack>
|
||||
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& step,
|
||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(
|
||||
src_desc, step, src_move_slice_window_iterator_hack);
|
||||
src_desc, step, src_move_slice_window_step_hack);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
}
|
||||
}
|
||||
|
||||
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowIteratorHack>
|
||||
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& step,
|
||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(
|
||||
src_desc, step, src_move_slice_window_iterator_hack);
|
||||
src_desc, step, src_move_slice_window_step_hack);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -84,11 +84,11 @@ template <index_t BlockSize,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
@@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
@@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
@@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
@@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
||||
c_thread_buf,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_buf,
|
||||
CGridIteratorHacks{});
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -145,11 +145,11 @@ template <index_t BlockSize,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
|
||||
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
|
||||
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
|
||||
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
|
||||
AGridMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
|
||||
BGridMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
|
||||
AGridMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
|
||||
BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||
@@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||
@@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
@@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
@@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
@@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridIteratorHacks{});
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -141,11 +141,11 @@ template <index_t BlockSize,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
@@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
@@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
@@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
@@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridIteratorHacks{});
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -42,11 +42,11 @@ template <index_t BlockSize,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
typename AGlobalStepHacks,
|
||||
typename BGlobalStepHacks,
|
||||
typename CGlobalStepHacks,
|
||||
typename AGlobalMoveSliceWindowStepHacks,
|
||||
typename BGlobalMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
@@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
|
||||
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack =
|
||||
AGlobalMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
||||
BGlobalMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
BGlobalMoveSliceWindowStepHacks{};
|
||||
|
||||
// double regsiter buffer for b
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
@@ -257,14 +256,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
|
||||
// LDS double buffer: preload data
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
|
||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
||||
}
|
||||
@@ -288,7 +287,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||
@@ -304,7 +303,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
@@ -327,7 +326,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
@@ -346,7 +345,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
// output: register to global memory
|
||||
{
|
||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + k_thread_id * KPerThread;
|
||||
@@ -370,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
c_thread_buf,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
c_global_buf,
|
||||
c_k_n_ho_wo_global_tensor_iterator_hacks);
|
||||
c_k_n_ho_wo_global_tensor_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -126,11 +126,11 @@ template <index_t BlockSize,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
@@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
|
||||
AGridMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
|
||||
BGridMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
@@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
@@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_iterator_hack);
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_iterator_hack);
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
@@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||
@@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
c_blk_buf_,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
}
|
||||
#else
|
||||
{
|
||||
@@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
auto c_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
||||
@@ -610,7 +604,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
|
||||
return c_thread_idx_;
|
||||
};
|
||||
@@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
@@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
@@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
@@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
||||
|
||||
@@ -11,7 +11,7 @@ namespace ck {
|
||||
// 1. Desc is known at compile-time
|
||||
// 2. Buffer is StaticBuffer
|
||||
// 3. OriginIdx is known at compile-time
|
||||
// 4. use #-iterator
|
||||
// 4. use #-step
|
||||
template <typename Data,
|
||||
typename Desc,
|
||||
typename SliceLengths,
|
||||
|
||||
@@ -66,7 +66,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx)
|
||||
@@ -84,13 +84,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
template <typename SrcSliceOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer,
|
||||
typename DstIteratorHacks>
|
||||
typename DstStepHacks>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
const DstIteratorHacks& dst_iterator_hacks)
|
||||
const DstStepHacks& dst_step_hacks)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
@@ -127,31 +127,31 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto dst_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto dst_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -236,12 +236,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
|
||||
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
|
||||
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -250,10 +250,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_iterator =
|
||||
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,11 +268,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||
|
||||
constexpr auto dst_iterator_hacks =
|
||||
constexpr auto dst_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
|
||||
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
@@ -345,7 +345,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
@@ -382,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx)
|
||||
@@ -400,13 +400,13 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
template <typename SrcBuffer,
|
||||
typename DstBuffer,
|
||||
typename DstSliceOriginIdx,
|
||||
typename SrcIteratorHacks>
|
||||
typename SrcStepHacks>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
@@ -441,31 +441,31 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto src_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto src_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -548,12 +548,12 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -562,10 +562,10 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_iterator =
|
||||
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,11 +580,11 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||
|
||||
constexpr auto src_iterator_hacks =
|
||||
constexpr auto src_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
|
||||
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
@@ -657,7 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
@@ -699,8 +699,8 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
||||
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
@@ -724,10 +724,9 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
@@ -755,31 +754,31 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto src_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto src_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -861,12 +860,12 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -875,17 +874,16 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_iterator =
|
||||
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename DstIteratorHacks>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
const DstIteratorHacks& dst_iterator_hacks)
|
||||
template <typename DstBuffer, typename DstStepHacks>
|
||||
__device__ void
|
||||
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||
{
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
@@ -913,35 +911,31 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto dst_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
const auto forward_iterator = make_tensor_coordinate_iterator(
|
||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
||||
|
||||
return forward_iterator;
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto dst_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
const auto backward_iterator = make_tensor_coordinate_iterator(
|
||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
||||
|
||||
return backward_iterator;
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -1025,12 +1019,12 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -1039,10 +1033,10 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_iterator =
|
||||
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1053,11 +1047,11 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||
|
||||
constexpr auto src_iterator_hacks =
|
||||
constexpr auto src_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunRead(src_desc, src_buf, src_iterator_hacks);
|
||||
RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
@@ -1067,11 +1061,11 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||
|
||||
constexpr auto dst_iterator_hacks =
|
||||
constexpr auto dst_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
|
||||
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
@@ -1204,17 +1198,17 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <typename SrcMoveSliceWindowIteratorHack>
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx,
|
||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
@@ -1222,8 +1216,8 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
@@ -1237,7 +1231,7 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
@@ -1260,7 +1254,7 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_ref_idx is known at run-time
|
||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||
// 5. use #-iterator
|
||||
// 5. use #-step
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
@@ -1287,7 +1281,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
|
||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||
@@ -1386,12 +1380,12 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
constexpr auto src_ref_to_data_disp_idx =
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||
|
||||
constexpr auto src_ref_to_data_disp_coord_iterator =
|
||||
make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
|
||||
constexpr auto src_ref_to_data_disp_coord_step =
|
||||
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||
|
||||
auto src_data_coord = src_ref_coord_;
|
||||
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||
|
||||
@@ -1431,7 +1425,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
|
||||
const auto src_slice_move_step_iter =
|
||||
make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||
|
||||
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||
}
|
||||
|
||||
@@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
||||
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
@@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
@@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto src_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto src_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_iterator(
|
||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_iterator =
|
||||
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename DstIteratorHacks>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
const DstIteratorHacks& dst_iterator_hacks)
|
||||
template <typename DstBuffer, typename DstStepHacks>
|
||||
__device__ void
|
||||
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||
{
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
@@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto dst_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
const auto forward_iterator = make_tensor_coordinate_iterator(
|
||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
||||
|
||||
return forward_iterator;
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto dst_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
const auto backward_iterator = make_tensor_coordinate_iterator(
|
||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
||||
|
||||
return backward_iterator;
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_iterator =
|
||||
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||
|
||||
constexpr auto src_iterator_hacks =
|
||||
constexpr auto src_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunRead(src_desc, src_buf, src_iterator_hacks);
|
||||
RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
@@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||
|
||||
constexpr auto dst_iterator_hacks =
|
||||
constexpr auto dst_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
|
||||
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
@@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <typename SrcMoveSliceWindowIteratorHack>
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx,
|
||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
@@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
@@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
@@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_ref_idx is known at run-time
|
||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||
// 5. use #-iterator
|
||||
// 5. use #-step
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
@@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
|
||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||
@@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
||||
constexpr auto src_ref_to_data_disp_idx =
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||
|
||||
constexpr auto src_ref_to_data_disp_coord_iterator =
|
||||
make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
|
||||
constexpr auto src_ref_to_data_disp_coord_step =
|
||||
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||
|
||||
auto src_data_coord = src_ref_coord_;
|
||||
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||
|
||||
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
||||
|
||||
@@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
|
||||
const auto src_slice_move_step_iter =
|
||||
make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||
|
||||
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||
}
|
||||
|
||||
@@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
@@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
@@ -184,11 +184,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||
@@ -249,16 +249,16 @@ extern "C" __global__ void
|
||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
@@ -266,21 +266,21 @@ extern "C" __global__ void
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
@@ -320,11 +320,11 @@ extern "C" __global__ void
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
||||
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
|
||||
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(
|
||||
using AGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
@@ -243,12 +243,12 @@ extern "C" __global__ void
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(
|
||||
using AGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
@@ -256,25 +256,25 @@ extern "C" __global__ void
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
@@ -316,11 +316,11 @@ extern "C" __global__ void
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
|
||||
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridIteratorHacks =
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
@@ -247,12 +247,12 @@ extern "C" __global__ void
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridIteratorHacks =
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
@@ -260,25 +260,25 @@ extern "C" __global__ void
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
@@ -316,11 +316,11 @@ extern "C" __global__ void
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
@@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridIteratorHacks =
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
@@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
@@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(
|
||||
using CGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
@@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowIteratorHacks =
|
||||
using BGridMoveSliceWindowStepHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
@@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
@@ -254,7 +254,7 @@ extern "C" __global__ void
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridIteratorHacks =
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
@@ -266,7 +266,7 @@ extern "C" __global__ void
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
@@ -278,7 +278,7 @@ extern "C" __global__ void
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(
|
||||
using CGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
@@ -294,9 +294,9 @@ extern "C" __global__ void
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowIteratorHacks =
|
||||
using BGridMoveSliceWindowStepHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
@@ -334,11 +334,11 @@ extern "C" __global__ void
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
|
||||
Reference in New Issue
Block a user