mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 00:39:02 +00:00
@@ -10,7 +10,7 @@ template <index_t NDimHidden, typename VisibleDimensionIds>
|
|||||||
struct TensorCoordinate;
|
struct TensorCoordinate;
|
||||||
|
|
||||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||||
struct TensorCoordinateIterator;
|
struct TensorCoordinateStep;
|
||||||
|
|
||||||
// Transforms: Tuple<transforms...>
|
// Transforms: Tuple<transforms...>
|
||||||
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
|
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
|
||||||
@@ -252,17 +252,16 @@ struct TensorCoordinate
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||||
struct TensorCoordinateIterator
|
struct TensorCoordinateStep
|
||||||
{
|
{
|
||||||
// TODO make these private
|
// TODO make these private
|
||||||
using VisibleIndex = MultiIndex<NDimVisible>;
|
using VisibleIndex = MultiIndex<NDimVisible>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
__host__ __device__ constexpr TensorCoordinateIterator() = default;
|
__host__ __device__ constexpr TensorCoordinateStep() = default;
|
||||||
|
|
||||||
__host__
|
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
|
||||||
__device__ constexpr TensorCoordinateIterator(const VisibleIndex& idx_diff_visible,
|
const MultiIndex<NTransform>& do_transforms)
|
||||||
const MultiIndex<NTransform>& do_transforms)
|
|
||||||
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
|
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens
|
|||||||
// UpdateLowerIndexHack: Sequence<...>
|
// UpdateLowerIndexHack: Sequence<...>
|
||||||
// HACK: control UpdateLowerIndex
|
// HACK: control UpdateLowerIndex
|
||||||
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
|
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
|
||||||
__host__ __device__ constexpr auto make_tensor_coordinate_iterator(
|
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||||
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
|
const VisibleIndex& idx_diff_visible,
|
||||||
|
UpdateLowerIndexHack)
|
||||||
{
|
{
|
||||||
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
||||||
"wrong! # of dimension inconsistent");
|
"wrong! # of dimension inconsistent");
|
||||||
@@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator(
|
|||||||
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
||||||
});
|
});
|
||||||
|
|
||||||
return TensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
|
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
|
||||||
idx_diff_visible, do_transforms};
|
do_transforms};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorDesc, typename VisibleIndex>
|
template <typename TensorDesc, typename VisibleIndex>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||||
make_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
|
const VisibleIndex& idx_diff_visible)
|
||||||
{
|
{
|
||||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
|
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
|
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
|
||||||
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
|
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||||
TensorCoord& coord,
|
TensorCoord& coord,
|
||||||
const TensorCoordIterator& coord_iterator)
|
const TensorCoordStep& coord_step)
|
||||||
{
|
{
|
||||||
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
||||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||||
@@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
|
|||||||
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
||||||
|
|
||||||
// initialize visible index diff
|
// initialize visible index diff
|
||||||
set_container_subset(idx_diff_hidden,
|
set_container_subset(
|
||||||
TensorDesc::GetVisibleDimensionIds(),
|
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
|
||||||
coord_iterator.GetVisibleIndexDiff());
|
|
||||||
|
|
||||||
// this is what needs to be updated
|
// this is what needs to be updated
|
||||||
auto& idx_hidden = coord.GetHiddenIndex();
|
auto& idx_hidden = coord.GetHiddenIndex();
|
||||||
@@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
|
|||||||
auto idx_hidden_pick_visible =
|
auto idx_hidden_pick_visible =
|
||||||
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
|
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
|
||||||
|
|
||||||
idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
|
idx_hidden_pick_visible += coord_step.GetIndexDiff();
|
||||||
|
|
||||||
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
|
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
|
||||||
|
|
||||||
// update rest of hidden index
|
// update rest of hidden index
|
||||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||||
if(coord_iterator.do_transforms_[itran])
|
if(coord_step.do_transforms_[itran])
|
||||||
{
|
{
|
||||||
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
||||||
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
||||||
@@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
|
|||||||
MultiIndex<dims_low.Size()> idx_diff_low;
|
MultiIndex<dims_low.Size()> idx_diff_low;
|
||||||
|
|
||||||
// HACK: control UpdateLowerIndex for Merge using hack
|
// HACK: control UpdateLowerIndex for Merge using hack
|
||||||
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
|
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
|
||||||
|
|
||||||
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
|
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
|
||||||
|
|
||||||
@@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate(
|
|||||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||||
|
|
||||||
template <typename TensorDesc>
|
template <typename TensorDesc>
|
||||||
using TensorCoordinateIterator_t = decltype(make_tensor_coordinate_iterator(
|
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
|
||||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
@@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SrcMoveSliceWindowIteratorHack controls index calculation when moving the slice window
|
// SrcMoveSliceWindowStepHack controls index calculation when moving the slice window
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& step,
|
const Index& step,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.MoveSrcSliceWindow(
|
threadwise_transfer_.MoveSrcSliceWindow(
|
||||||
src_desc, step, src_move_slice_window_iterator_hack);
|
src_desc, step, src_move_slice_window_step_hack);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SrcMoveSliceWindowIteratorHack controls index calculation when moving the slice window
|
// SrcMoveSliceWindowStepHack controls index calculation when moving the slice window
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& step,
|
const Index& step,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.MoveSrcSliceWindow(
|
threadwise_transfer_.MoveSrcSliceWindow(
|
||||||
src_desc, step, src_move_slice_window_iterator_hack);
|
src_desc, step, src_move_slice_window_step_hack);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -84,11 +84,11 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
@@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
|||||||
// LDS double buffer: preload data into LDS
|
// LDS double buffer: preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||||
@@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
|||||||
// even iteration
|
// even iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load next data from device mem
|
// LDS double buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||||
@@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
|||||||
// odd iteration
|
// odd iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load next data from device mem
|
// LDS double buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
|||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load last data from device mem
|
// LDS double buffer: load last data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
CGridIteratorHacks{});
|
CGridStepHacks{});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -145,11 +145,11 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
struct GridwiseGemmDlops_km_kn_mn_v1r2
|
struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
@@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
|||||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||||
|
|
||||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||||
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
|
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
|
||||||
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
|
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
|
||||||
|
|
||||||
// hack to control index calculation when move slice window for A and B matrix for
|
// hack to control index calculation when move slice window for A and B matrix for
|
||||||
// threadwise copy
|
// threadwise copy
|
||||||
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
|
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
|
||||||
AGridMoveSliceWindowIteratorHacks{};
|
AGridMoveSliceWindowStepHacks{};
|
||||||
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
|
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
|
||||||
BGridMoveSliceWindowIteratorHacks{};
|
BGridMoveSliceWindowStepHacks{};
|
||||||
|
|
||||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||||
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||||
@@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
|||||||
// LDS double buffer: preload data into LDS
|
// LDS double buffer: preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||||
@@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
|||||||
do
|
do
|
||||||
{
|
{
|
||||||
// even iteration
|
// even iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(
|
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||||
a_k_m0_m1_grid_desc,
|
a_block_slice_copy_step,
|
||||||
a_block_slice_copy_step,
|
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(
|
b_block_slice_copy_step,
|
||||||
b_k_n0_n1_grid_desc,
|
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||||
b_block_slice_copy_step,
|
|
||||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load next data from device mem
|
// LDS double buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||||
@@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
|||||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||||
|
|
||||||
// odd iteration
|
// odd iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(
|
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||||
a_k_m0_m1_grid_desc,
|
a_block_slice_copy_step,
|
||||||
a_block_slice_copy_step,
|
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(
|
b_block_slice_copy_step,
|
||||||
b_k_n0_n1_grid_desc,
|
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||||
b_block_slice_copy_step,
|
|
||||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load next data from device mem
|
// LDS double buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
|||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load last data from device mem
|
// LDS double buffer: load last data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
CGridIteratorHacks{});
|
CGridStepHacks{});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -141,11 +141,11 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
@@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
|||||||
|
|
||||||
// LDS double buffer: preload data into LDS
|
// LDS double buffer: preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||||
@@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
|||||||
// even iteration
|
// even iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load next data from device mem
|
// LDS double buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
|
||||||
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||||
@@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
|||||||
// odd iteration
|
// odd iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
|
||||||
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
|||||||
// LDS double buffer: tail
|
// LDS double buffer: tail
|
||||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(
|
||||||
a_block_slice_copy_step,
|
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
b_blockwise_copy.MoveSrcSliceWindow(
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
|
||||||
b_block_slice_copy_step,
|
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load last data from device mem
|
// LDS double buffer: load last data from device mem
|
||||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
CGridIteratorHacks{});
|
CGridStepHacks{});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -42,11 +42,11 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGlobalIteratorHacks,
|
typename AGlobalStepHacks,
|
||||||
typename BGlobalIteratorHacks,
|
typename BGlobalStepHacks,
|
||||||
typename CGlobalIteratorHacks,
|
typename CGlobalStepHacks,
|
||||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
typename AGlobalMoveSliceWindowStepHacks,
|
||||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
typename BGlobalMoveSliceWindowStepHacks>
|
||||||
struct GridwiseGemmDlops_km_kn_mn_v3
|
struct GridwiseGemmDlops_km_kn_mn_v3
|
||||||
{
|
{
|
||||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||||
@@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||||
|
|
||||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||||
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
|
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
|
||||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
|
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
|
||||||
|
|
||||||
// hack to control index calculation when move slice window for A and B matrix for
|
// hack to control index calculation when move slice window for A and B matrix for
|
||||||
// threadwise copy
|
// threadwise copy
|
||||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack =
|
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
|
||||||
AGlobalMoveSliceWindowIteratorHacks{};
|
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
BGlobalMoveSliceWindowStepHacks{};
|
||||||
BGlobalMoveSliceWindowIteratorHacks{};
|
|
||||||
|
|
||||||
// double regsiter buffer for b
|
// double regsiter buffer for b
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
@@ -257,14 +256,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
|
|
||||||
// LDS double buffer: preload data
|
// LDS double buffer: preload data
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
|
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
|
||||||
|
|
||||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||||
b_global_buf,
|
b_global_buf,
|
||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_even_buf,
|
b_thread_even_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
||||||
}
|
}
|
||||||
@@ -288,7 +287,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_odd_buf,
|
b_thread_odd_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||||
@@ -304,7 +303,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_even_buf,
|
b_thread_even_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||||
@@ -327,7 +326,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_odd_buf,
|
b_thread_odd_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||||
@@ -346,7 +345,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
// output: register to global memory
|
// output: register to global memory
|
||||||
{
|
{
|
||||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
|
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
|
||||||
|
|
||||||
const index_t k_thread_data_on_global =
|
const index_t k_thread_data_on_global =
|
||||||
k_block_data_on_global + k_thread_id * KPerThread;
|
k_block_data_on_global + k_thread_id * KPerThread;
|
||||||
@@ -370,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_k_n_ho_wo_global_desc,
|
c_k_n_ho_wo_global_desc,
|
||||||
c_global_buf,
|
c_global_buf,
|
||||||
c_k_n_ho_wo_global_tensor_iterator_hacks);
|
c_k_n_ho_wo_global_tensor_step_hacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -126,11 +126,11 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks,
|
typename BGridMoveSliceWindowStepHacks,
|
||||||
bool CAccessOrderMRepeatNRepeat>
|
bool CAccessOrderMRepeatNRepeat>
|
||||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||||
{
|
{
|
||||||
@@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||||
|
|
||||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||||
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
|
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||||
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
|
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
|
||||||
|
|
||||||
// hack to control index calculation when move slice window for A and B matrix for
|
// hack to control index calculation when move slice window for A and B matrix for
|
||||||
// threadwise copy
|
// threadwise copy
|
||||||
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
|
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
|
||||||
AGridMoveSliceWindowIteratorHacks{};
|
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
|
||||||
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
|
|
||||||
BGridMoveSliceWindowIteratorHacks{};
|
|
||||||
|
|
||||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||||
@@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
|
|
||||||
// preload data into LDS
|
// preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
|
||||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||||
@@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
a_k0_m_k1_grid_move_slice_window_iterator_hack);
|
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
b_k0_n_k1_grid_move_slice_window_iterator_hack);
|
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||||
|
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
|
||||||
|
|
||||||
block_sync_lds();
|
block_sync_lds();
|
||||||
|
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
|
||||||
|
|
||||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||||
|
|
||||||
@@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
const index_t n_thread_data_on_grid =
|
const index_t n_thread_data_on_grid =
|
||||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
|
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||||
|
|
||||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||||
@@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_blk_buf_,
|
c_blk_buf_,
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
{
|
{
|
||||||
@@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
const index_t n_thread_data_on_grid =
|
const index_t n_thread_data_on_grid =
|
||||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
|
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||||
|
|
||||||
auto c_thread_copy =
|
auto c_thread_copy =
|
||||||
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
||||||
@@ -610,7 +604,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
|
|
||||||
return c_thread_idx_;
|
return c_thread_idx_;
|
||||||
};
|
};
|
||||||
@@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||||
@@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||||
@@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||||
@@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ namespace ck {
|
|||||||
// 1. Desc is known at compile-time
|
// 1. Desc is known at compile-time
|
||||||
// 2. Buffer is StaticBuffer
|
// 2. Buffer is StaticBuffer
|
||||||
// 3. OriginIdx is known at compile-time
|
// 3. OriginIdx is known at compile-time
|
||||||
// 4. use #-iterator
|
// 4. use #-step
|
||||||
template <typename Data,
|
template <typename Data,
|
||||||
typename Desc,
|
typename Desc,
|
||||||
typename SliceLengths,
|
typename SliceLengths,
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
|
|
||||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||||
|
|
||||||
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
|
||||||
const Index& dst_slice_origin_idx)
|
const Index& dst_slice_origin_idx)
|
||||||
@@ -84,13 +84,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
template <typename SrcSliceOriginIdx,
|
template <typename SrcSliceOriginIdx,
|
||||||
typename SrcBuffer,
|
typename SrcBuffer,
|
||||||
typename DstBuffer,
|
typename DstBuffer,
|
||||||
typename DstIteratorHacks>
|
typename DstStepHacks>
|
||||||
__device__ void Run(const SrcDesc&,
|
__device__ void Run(const SrcDesc&,
|
||||||
const SrcSliceOriginIdx&,
|
const SrcSliceOriginIdx&,
|
||||||
const SrcBuffer& src_buf,
|
const SrcBuffer& src_buf,
|
||||||
const DstDesc& dst_desc,
|
const DstDesc& dst_desc,
|
||||||
DstBuffer& dst_buf,
|
DstBuffer& dst_buf,
|
||||||
const DstIteratorHacks& dst_iterator_hacks)
|
const DstStepHacks& dst_step_hacks)
|
||||||
{
|
{
|
||||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! SrcDesc need to known at compile-time");
|
"wrong! SrcDesc need to known at compile-time");
|
||||||
@@ -127,31 +127,31 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
constexpr auto ordered_access_lengths =
|
constexpr auto ordered_access_lengths =
|
||||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto dst_forward_iterators = generate_tuple(
|
const auto dst_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto dst_backward_iterators = generate_tuple(
|
const auto dst_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -236,12 +236,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -250,10 +250,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
// move dst coordinate back to slice origin (or not)
|
// move dst coordinate back to slice origin (or not)
|
||||||
if constexpr(DstResetCoordinateAfterRun)
|
if constexpr(DstResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto dst_reset_iterator =
|
const auto dst_reset_step =
|
||||||
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||||
|
|
||||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,11 +268,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||||
|
|
||||||
constexpr auto dst_iterator_hacks =
|
constexpr auto dst_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
|
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||||
@@ -345,7 +345,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
|||||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||||
|
|
||||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
@@ -382,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
|
|
||||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin_idx)
|
const Index& src_slice_origin_idx)
|
||||||
@@ -400,13 +400,13 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
template <typename SrcBuffer,
|
template <typename SrcBuffer,
|
||||||
typename DstBuffer,
|
typename DstBuffer,
|
||||||
typename DstSliceOriginIdx,
|
typename DstSliceOriginIdx,
|
||||||
typename SrcIteratorHacks>
|
typename SrcStepHacks>
|
||||||
__device__ void Run(const SrcDesc& src_desc,
|
__device__ void Run(const SrcDesc& src_desc,
|
||||||
const SrcBuffer& src_buf,
|
const SrcBuffer& src_buf,
|
||||||
const DstDesc&,
|
const DstDesc&,
|
||||||
const DstSliceOriginIdx&,
|
const DstSliceOriginIdx&,
|
||||||
DstBuffer& dst_buf,
|
DstBuffer& dst_buf,
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
const SrcStepHacks& src_step_hacks)
|
||||||
{
|
{
|
||||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! DstDesc need to known at compile-time");
|
"wrong! DstDesc need to known at compile-time");
|
||||||
@@ -441,31 +441,31 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
constexpr auto ordered_access_lengths =
|
constexpr auto ordered_access_lengths =
|
||||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto src_forward_iterators = generate_tuple(
|
const auto src_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto src_backward_iterators = generate_tuple(
|
const auto src_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -548,12 +548,12 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
|
src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
|
src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -562,10 +562,10 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
// move src coordinate back to slice origin (or not)
|
// move src coordinate back to slice origin (or not)
|
||||||
if constexpr(SrcResetCoordinateAfterRun)
|
if constexpr(SrcResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto src_reset_iterator =
|
const auto src_reset_step =
|
||||||
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -580,11 +580,11 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||||
|
|
||||||
constexpr auto src_iterator_hacks =
|
constexpr auto src_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
|
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||||
@@ -657,7 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v2
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
@@ -699,8 +699,8 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin,
|
const Index& src_slice_origin,
|
||||||
@@ -724,10 +724,9 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -755,31 +754,31 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
constexpr auto ordered_src_access_lengths =
|
constexpr auto ordered_src_access_lengths =
|
||||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto src_forward_iterators = generate_tuple(
|
const auto src_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto src_backward_iterators = generate_tuple(
|
const auto src_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -861,12 +860,12 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -875,17 +874,16 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
// move src coordinate back to slice origin (or not)
|
// move src coordinate back to slice origin (or not)
|
||||||
if constexpr(SrcResetCoordinateAfterRun)
|
if constexpr(SrcResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto src_reset_iterator =
|
const auto src_reset_step =
|
||||||
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer, typename DstIteratorHacks>
|
template <typename DstBuffer, typename DstStepHacks>
|
||||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
__device__ void
|
||||||
DstBuffer& dst_buf,
|
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||||
const DstIteratorHacks& dst_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -913,35 +911,31 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
constexpr auto ordered_dst_access_lengths =
|
constexpr auto ordered_dst_access_lengths =
|
||||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto dst_forward_iterators = generate_tuple(
|
const auto dst_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto forward_iterator = make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||||
|
|
||||||
return forward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto dst_backward_iterators = generate_tuple(
|
const auto dst_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto backward_iterator = make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||||
|
|
||||||
return backward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -1025,12 +1019,12 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -1039,10 +1033,10 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
// move dst coordinate back to slice origin (or not)
|
// move dst coordinate back to slice origin (or not)
|
||||||
if constexpr(DstResetCoordinateAfterRun)
|
if constexpr(DstResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto dst_reset_iterator =
|
const auto dst_reset_step =
|
||||||
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||||
|
|
||||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1053,11 +1047,11 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||||
|
|
||||||
constexpr auto src_iterator_hacks =
|
constexpr auto src_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunRead(src_desc, src_buf, src_iterator_hacks);
|
RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer>
|
template <typename DstBuffer>
|
||||||
@@ -1067,11 +1061,11 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||||
|
|
||||||
constexpr auto dst_iterator_hacks =
|
constexpr auto dst_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
|
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||||
@@ -1204,17 +1198,17 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin_step_idx,
|
const Index& src_slice_origin_step_idx,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||||
const auto adjusted_step_idx =
|
const auto adjusted_step_idx =
|
||||||
@@ -1222,8 +1216,8 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(
|
const auto adjusted_step = make_tensor_coordinate_step(
|
||||||
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
|
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
@@ -1237,7 +1231,7 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||||
|
|
||||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
@@ -1260,7 +1254,7 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
// 2. SrcBuffer is DynamicBuffer
|
// 2. SrcBuffer is DynamicBuffer
|
||||||
// 3. src_ref_idx is known at run-time
|
// 3. src_ref_idx is known at run-time
|
||||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||||
// 5. use #-iterator
|
// 5. use #-step
|
||||||
// 2. dst:
|
// 2. dst:
|
||||||
// 1. DstDesc is known at compile-time
|
// 1. DstDesc is known at compile-time
|
||||||
// 2. DstBuffer is StaticBuffer
|
// 2. DstBuffer is StaticBuffer
|
||||||
@@ -1287,7 +1281,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
|||||||
|
|
||||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
|
||||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||||
@@ -1386,12 +1380,12 @@ struct ThreadwiseTensorSliceTransfer_v4
|
|||||||
constexpr auto src_ref_to_data_disp_idx =
|
constexpr auto src_ref_to_data_disp_idx =
|
||||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||||
|
|
||||||
constexpr auto src_ref_to_data_disp_coord_iterator =
|
constexpr auto src_ref_to_data_disp_coord_step =
|
||||||
make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
|
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||||
|
|
||||||
auto src_data_coord = src_ref_coord_;
|
auto src_data_coord = src_ref_coord_;
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
|
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||||
|
|
||||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||||
|
|
||||||
@@ -1431,7 +1425,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
|||||||
constexpr auto src_desc = SrcDesc{};
|
constexpr auto src_desc = SrcDesc{};
|
||||||
|
|
||||||
const auto src_slice_move_step_iter =
|
const auto src_slice_move_step_iter =
|
||||||
make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx));
|
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||||
|
|
||||||
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
using DstCoordIterator = decltype(make_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin,
|
const Index& src_slice_origin,
|
||||||
@@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
constexpr auto ordered_src_access_lengths =
|
constexpr auto ordered_src_access_lengths =
|
||||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto src_forward_iterators = generate_tuple(
|
const auto src_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto src_backward_iterators = generate_tuple(
|
const auto src_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
// move src coordinate back to slice origin (or not)
|
// move src coordinate back to slice origin (or not)
|
||||||
if constexpr(SrcResetCoordinateAfterRun)
|
if constexpr(SrcResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto src_reset_iterator =
|
const auto src_reset_step =
|
||||||
make_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer, typename DstIteratorHacks>
|
template <typename DstBuffer, typename DstStepHacks>
|
||||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
__device__ void
|
||||||
DstBuffer& dst_buf,
|
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||||
const DstIteratorHacks& dst_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
constexpr auto ordered_dst_access_lengths =
|
constexpr auto ordered_dst_access_lengths =
|
||||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto dst_forward_iterators = generate_tuple(
|
const auto dst_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto forward_iterator = make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||||
|
|
||||||
return forward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto dst_backward_iterators = generate_tuple(
|
const auto dst_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto backward_iterator = make_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||||
|
|
||||||
return backward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
// move dst coordinate back to slice origin (or not)
|
// move dst coordinate back to slice origin (or not)
|
||||||
if constexpr(DstResetCoordinateAfterRun)
|
if constexpr(DstResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto dst_reset_iterator =
|
const auto dst_reset_step =
|
||||||
make_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||||
|
|
||||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||||
|
|
||||||
constexpr auto src_iterator_hacks =
|
constexpr auto src_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunRead(src_desc, src_buf, src_iterator_hacks);
|
RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer>
|
template <typename DstBuffer>
|
||||||
@@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||||
|
|
||||||
constexpr auto dst_iterator_hacks =
|
constexpr auto dst_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
|
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||||
@@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin_step_idx,
|
const Index& src_slice_origin_step_idx,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||||
const auto adjusted_step_idx =
|
const auto adjusted_step_idx =
|
||||||
@@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(
|
const auto adjusted_step = make_tensor_coordinate_step(
|
||||||
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
|
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
@@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||||
|
|
||||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
@@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
// 2. SrcBuffer is DynamicBuffer
|
// 2. SrcBuffer is DynamicBuffer
|
||||||
// 3. src_ref_idx is known at run-time
|
// 3. src_ref_idx is known at run-time
|
||||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||||
// 5. use #-iterator
|
// 5. use #-step
|
||||||
// 2. dst:
|
// 2. dst:
|
||||||
// 1. DstDesc is known at compile-time
|
// 1. DstDesc is known at compile-time
|
||||||
// 2. DstBuffer is StaticBuffer
|
// 2. DstBuffer is StaticBuffer
|
||||||
@@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
|||||||
|
|
||||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
|
||||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||||
@@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
|||||||
constexpr auto src_ref_to_data_disp_idx =
|
constexpr auto src_ref_to_data_disp_idx =
|
||||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||||
|
|
||||||
constexpr auto src_ref_to_data_disp_coord_iterator =
|
constexpr auto src_ref_to_data_disp_coord_step =
|
||||||
make_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
|
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||||
|
|
||||||
auto src_data_coord = src_ref_coord_;
|
auto src_data_coord = src_ref_coord_;
|
||||||
|
|
||||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
|
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||||
|
|
||||||
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
||||||
|
|
||||||
@@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
|||||||
constexpr auto src_desc = SrcDesc{};
|
constexpr auto src_desc = SrcDesc{};
|
||||||
|
|
||||||
const auto src_slice_move_step_iter =
|
const auto src_slice_move_step_iter =
|
||||||
make_tensor_coordinate_iterator(src_desc, to_multi_index(src_slice_move_step_idx));
|
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||||
|
|
||||||
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
|||||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
using BGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||||
@@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
using GridwiseGemm =
|
using GridwiseGemm =
|
||||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||||
@@ -184,11 +184,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||||
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||||
@@ -249,16 +249,16 @@ extern "C" __global__ void
|
|||||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
using BGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||||
@@ -266,21 +266,21 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
using GridwiseGemm =
|
using GridwiseGemm =
|
||||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||||
@@ -320,11 +320,11 @@ extern "C" __global__ void
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
||||||
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||||
|
|||||||
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
|||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(
|
using AGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
using BGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
using GridwiseGemm =
|
using GridwiseGemm =
|
||||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
false>;
|
false>;
|
||||||
|
|
||||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
@@ -243,12 +243,12 @@ extern "C" __global__ void
|
|||||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(
|
using AGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
using BGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -256,25 +256,25 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||||
@@ -316,11 +316,11 @@ extern "C" __global__ void
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
false>;
|
false>;
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||||
|
|||||||
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
|||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
using AGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using GridwiseGemm =
|
using GridwiseGemm =
|
||||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
false>;
|
false>;
|
||||||
|
|
||||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
@@ -247,12 +247,12 @@ extern "C" __global__ void
|
|||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
using AGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -260,25 +260,25 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using GridwiseGemm =
|
using GridwiseGemm =
|
||||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
@@ -316,11 +316,11 @@ extern "C" __global__ void
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
false>;
|
false>;
|
||||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
|||||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
using AGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||||
@@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||||
@@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(
|
using CGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||||
@@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using BGridMoveSliceWindowIteratorHacks =
|
using BGridMoveSliceWindowStepHacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using GridwiseContraction =
|
using GridwiseContraction =
|
||||||
@@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||||
{
|
{
|
||||||
@@ -254,7 +254,7 @@ extern "C" __global__ void
|
|||||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
using AGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||||
@@ -266,7 +266,7 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||||
@@ -278,7 +278,7 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(
|
using CGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||||
@@ -294,9 +294,9 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using BGridMoveSliceWindowIteratorHacks =
|
using BGridMoveSliceWindowStepHacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using GridwiseContraction =
|
using GridwiseContraction =
|
||||||
@@ -334,11 +334,11 @@ extern "C" __global__ void
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||||
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
|||||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -215,7 +215,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||||
|
|
||||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -223,7 +223,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||||
|
|
||||||
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
|
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
||||||
@@ -243,10 +243,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -287,11 +287,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
||||||
6,
|
6,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false // CAccessOrderMRepeatNRepeat
|
false // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
@@ -299,11 +299,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
|||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
in_gemmm_gemmn_grid_desc,
|
in_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
out_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
in_m0_m1_m2_n_grid_iterator_hacks,
|
in_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
|||||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -187,7 +187,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -195,7 +195,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||||
|
|
||||||
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
|
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||||
@@ -215,10 +215,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
||||||
|
|
||||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -263,11 +263,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
|||||||
#endif
|
#endif
|
||||||
7,
|
7,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
true // CAccessOrderMRepeatNRepeat
|
true // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
@@ -275,11 +275,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
|||||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
in_gemmm_gemmn_grid_desc,
|
in_gemmm_gemmn_grid_desc,
|
||||||
out_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
in_m0_m1_m2_n_grid_iterator_hacks,
|
in_m0_m1_m2_n_grid_step_hacks,
|
||||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
in_right_pads);
|
in_right_pads);
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -99,7 +99,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
|
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||||
@@ -107,7 +107,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -121,10 +121,10 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||||
@@ -171,22 +171,22 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||||
5, // CThreadTransferSrcDstVectorDim
|
5, // CThreadTransferSrcDstVectorDim
|
||||||
GemmCThreadTransferDstScalarPerVector_N11,
|
GemmCThreadTransferDstScalarPerVector_N11,
|
||||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
|
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
|
||||||
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
|
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
|
||||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
|
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||||
wei_gemmk_gemmm_grid_desc,
|
wei_gemmk_gemmm_grid_desc,
|
||||||
in_gemmk_gemmn_grid_desc,
|
in_gemmk_gemmn_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
|
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
|
||||||
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
|
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
|
||||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
|
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
float perf = static_cast<float>(calculate_convolution_flops(
|
float perf = static_cast<float>(calculate_convolution_flops(
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
||||||
@@ -165,7 +165,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
||||||
@@ -175,7 +175,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||||
|
|
||||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
||||||
@@ -189,10 +189,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
||||||
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -231,22 +231,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
||||||
5, // CThreadTransferSrcDstVectorDim
|
5, // CThreadTransferSrcDstVectorDim
|
||||||
GemmCThreadTransferDstScalarPerVector_N11,
|
GemmCThreadTransferDstScalarPerVector_N11,
|
||||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
|
||||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
|
||||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
|
||||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -92,12 +92,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -123,10 +123,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -167,22 +167,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||||
7,
|
7,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
float perf = static_cast<float>(calculate_convolution_flops(
|
float perf = static_cast<float>(calculate_convolution_flops(
|
||||||
|
|||||||
@@ -121,12 +121,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -134,7 +134,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -144,10 +144,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -187,22 +187,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<2, 3, 0, 1>,
|
Sequence<2, 3, 0, 1>,
|
||||||
2,
|
2,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -182,12 +182,12 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -195,7 +195,7 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -213,10 +213,10 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -256,11 +256,11 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||||
6,
|
6,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false // CAccessOrderMRepeatNRepeat
|
false // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
@@ -268,11 +268,11 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
|||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
||||||
@@ -241,7 +241,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||||
@@ -249,7 +249,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||||
@@ -267,10 +267,10 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
||||||
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -311,11 +311,11 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
|||||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||||
7,
|
7,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false // CAccessOrderMRepeatNRepeat
|
false // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
@@ -323,11 +323,11 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
|||||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
in_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_grid_iterator_hacks =
|
constexpr auto wei_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||||
@@ -142,7 +142,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||||
|
|
||||||
constexpr auto in_grid_iterator_hacks = make_tuple(
|
constexpr auto in_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||||
@@ -154,7 +154,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||||
|
|
||||||
constexpr auto out_grid_iterator_hacks = make_tuple(
|
constexpr auto out_grid_step_hacks = make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||||
@@ -170,9 +170,9 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
||||||
|
|
||||||
constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
@@ -211,22 +211,22 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||||
5, // CThreadTransferSrcDstVectorDim
|
5, // CThreadTransferSrcDstVectorDim
|
||||||
CThreadTransferDstScalarPerVector_BN1,
|
CThreadTransferDstScalarPerVector_BN1,
|
||||||
decltype(wei_grid_iterator_hacks),
|
decltype(wei_grid_step_hacks),
|
||||||
decltype(in_grid_iterator_hacks),
|
decltype(in_grid_step_hacks),
|
||||||
decltype(out_grid_iterator_hacks),
|
decltype(out_grid_step_hacks),
|
||||||
decltype(wei_grid_move_slice_window_iterator_hacks),
|
decltype(wei_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_grid_move_slice_window_iterator_hacks)>(
|
decltype(in_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||||
wei_grid_desc_gk0_gm0_gm1_gk1,
|
wei_grid_desc_gk0_gm0_gm1_gk1,
|
||||||
in_grid_desc_gk0_gn0_gn1_gk1,
|
in_grid_desc_gk0_gn0_gn1_gk1,
|
||||||
out_grid_desc_gm0_gm1_gn0_gn1,
|
out_grid_desc_gm0_gm1_gn0_gn1,
|
||||||
wei_grid_iterator_hacks,
|
wei_grid_step_hacks,
|
||||||
in_grid_iterator_hacks,
|
in_grid_step_hacks,
|
||||||
out_grid_iterator_hacks,
|
out_grid_step_hacks,
|
||||||
wei_grid_move_slice_window_iterator_hacks,
|
wei_grid_move_slice_window_step_hacks,
|
||||||
in_grid_move_slice_window_iterator_hacks,
|
in_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
float perf = static_cast<float>(calculate_convolution_flops(
|
float perf = static_cast<float>(calculate_convolution_flops(
|
||||||
|
|||||||
@@ -39,11 +39,11 @@ template <ck::index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
__host__ float
|
__host__ float
|
||||||
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||||
const FloatAB* p_b_grid,
|
const FloatAB* p_b_grid,
|
||||||
@@ -51,11 +51,11 @@ driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -104,11 +104,11 @@ driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||||
|
|
||||||
|
|||||||
@@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||||
constexpr auto a_e_k_global_iterator_hacks =
|
constexpr auto a_e_k_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
|
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
@@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||||
|
|
||||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||||
// hack for NKHW format
|
// hack for NKHW format
|
||||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
|
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
Sequence<0, 2, 3, 1>,
|
Sequence<0, 2, 3, 1>,
|
||||||
0,
|
0,
|
||||||
CThreadTransferDstScalarPerVector_W,
|
CThreadTransferDstScalarPerVector_W,
|
||||||
decltype(a_e_k_global_iterator_hacks),
|
decltype(a_e_k_global_step_hacks),
|
||||||
decltype(b_e_n_ho_wo_global_iterator_hacks),
|
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||||
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
|
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||||
decltype(a_e_k_global_move_slice_window_iterator_hack),
|
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||||
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
|
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||||
|
|
||||||
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
||||||
|
|
||||||
|
|||||||
@@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||||
constexpr auto a_e_k_global_iterator_hacks =
|
constexpr auto a_e_k_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
|
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
@@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||||
|
|
||||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||||
// hack for NKHW format
|
// hack for NKHW format
|
||||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
|
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
Sequence<0, 2, 3, 1>,
|
Sequence<0, 2, 3, 1>,
|
||||||
0,
|
0,
|
||||||
CThreadTransferDstScalarPerVector_W,
|
CThreadTransferDstScalarPerVector_W,
|
||||||
decltype(a_e_k_global_iterator_hacks),
|
decltype(a_e_k_global_step_hacks),
|
||||||
decltype(b_e_n_ho_wo_global_iterator_hacks),
|
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||||
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
|
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||||
decltype(a_e_k_global_move_slice_window_iterator_hack),
|
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||||
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
|
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||||
|
|
||||||
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||||
|
|
||||||
|
|||||||
@@ -43,22 +43,22 @@ template <ck::index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||||
const FloatAB* p_b_grid,
|
const FloatAB* p_b_grid,
|
||||||
FloatC* p_c_grid,
|
FloatC* p_c_grid,
|
||||||
const AKMGridDesc& a_k_m_grid_desc,
|
const AKMGridDesc& a_k_m_grid_desc,
|
||||||
const BKNGridDesc& b_k_n_grid_desc,
|
const BKNGridDesc& b_k_n_grid_desc,
|
||||||
const CMNGridDesc& c_m_n_grid_desc,
|
const CMNGridDesc& c_m_n_grid_desc,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -109,11 +109,11 @@ __host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||||
|
|||||||
@@ -39,22 +39,22 @@ template <ck::index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||||
const FloatAB* p_b_grid,
|
const FloatAB* p_b_grid,
|
||||||
FloatC* p_c_grid,
|
FloatC* p_c_grid,
|
||||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||||
const CMNGridDesc& c_m_n_grid_desc,
|
const CMNGridDesc& c_m_n_grid_desc,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -102,11 +102,11 @@ __host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||||
|
|||||||
@@ -41,11 +41,11 @@ template <ck::index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks,
|
typename BGridMoveSliceWindowStepHacks,
|
||||||
bool CAccessOrderMRepeatNRepeat>
|
bool CAccessOrderMRepeatNRepeat>
|
||||||
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||||
const FloatAB* p_b_grid,
|
const FloatAB* p_b_grid,
|
||||||
@@ -53,11 +53,11 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
|||||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||||
const CMNGridDesc& c_m_n_grid_desc,
|
const CMNGridDesc& c_m_n_grid_desc,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -103,11 +103,11 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
CAccessOrderMRepeatNRepeat>;
|
CAccessOrderMRepeatNRepeat>;
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user