mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Reorganize files, Part 1 (#119)
* delete obsolete files
* move files
* build
* update cmake
* update cmake
* fix build
* reorg examples
* update cmake for example and test
@@ -0,0 +1,925 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP

#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {
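// NOTE (editorial, inferred from the code below): the kernel entry points in this header are
// thin wrappers that pick, at compile time via NeedIndices, between the plain reduction
// (GridwiseReduction::Run) and the index-tracking variant (GridwiseReduction::RunWithIndex /
// RunSecondCallWithIndex), e.g. for min/max reductions that must also report index positions.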
template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename OutElementwiseOperation>
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
                                        const OutGridDesc_M out_grid_desc_m,
                                        const InElementwiseOperation in_elementwise_op,
                                        const OutElementwiseOperation acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_global,
                                        OutDataType beta,
                                        OutDataType* const __restrict__ p_out_global,
                                        const IndexDataType* const __restrict__ p_ws_indices_global,
                                        IndexDataType* const __restrict__ p_indices_global)
{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_global,
                               beta,
                               p_out_global,
                               p_ws_indices_global,
                               p_indices_global);
    }
    else
    {
        GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
                                        out_grid_desc_m,
                                        in_elementwise_op,
                                        acc_elementwise_op,
                                        alpha,
                                        p_in_global,
                                        beta,
                                        p_out_global,
                                        p_ws_indices_global,
                                        p_indices_global);
    };
};
template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename OutElementwiseOperation>
__global__ void
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
                                    const OutGridDesc_M out_grid_desc_m,
                                    const InElementwiseOperation in_elementwise_op,
                                    const OutElementwiseOperation acc_elementwise_op,
                                    AccDataType alpha,
                                    const InDataType* const __restrict__ p_in_global,
                                    OutDataType beta,
                                    OutDataType* const __restrict__ p_out_global,
                                    const IndexDataType* const __restrict__ p_ws_indices_global,
                                    IndexDataType* const __restrict__ p_indices_global)
{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_global,
                               beta,
                               p_out_global,
                               p_ws_indices_global,
                               p_indices_global);
    }
    else
    {
        GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
                                                  out_grid_desc_m,
                                                  in_elementwise_op,
                                                  acc_elementwise_op,
                                                  alpha,
                                                  p_in_global,
                                                  beta,
                                                  p_out_global,
                                                  p_ws_indices_global,
                                                  p_indices_global);
    };
};
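// NOTE (editorial summary of the template parameters, inferred from their usage below):
// GridwiseReduction_mk_to_m_blockwise reduces a 2D [M, K] view along K, one M block-tile per
// workgroup. MThreadClusterSize x KThreadClusterSize is the thread grid within a block (it
// should multiply to BlockSize), {M,K}ThreadSliceSize are the per-thread element counts,
// InSrcVectorDim/InSrcVectorSize control vectorized source loads, OutDstVectorSize controls
// the vectorized final store, and BetaIsZero elides the beta read-modify-write at compile time.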
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename OutElementwiseOperation,
          bool PropagateNan,
          bool BetaIsZero,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer_1d_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
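    // For example (editorial): with MThreadClusterSize = 8, KThreadClusterSize = 32,
    // MThreadSliceSize = 4 and KThreadSliceSize = 8, a 256-thread block covers an
    // M_BlockTileSize x K_BlockTileSize = 32 x 256 tile of the [M, K] input per iteration.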
    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const OutElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_global,
                               OutDataType beta,
                               OutDataType* const __restrict__ p_out_global,
                               const IndexDataType* const __restrict__ p_ws_indices_global,
                               IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
                                                                        AccDataType,
                                                                        BlockSize,
                                                                        MThreadClusterSize,
                                                                        KThreadClusterSize,
                                                                        reorder_thread_cluster,
                                                                        ReduceOperation,
                                                                        PropagateNan>;
        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        (void)p_ws_indices_global;
        (void)p_indices_global;

        // LDS
        __shared__ AccDataType p_block_reduce_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
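        // Editorial note: map the flat thread id onto an (m, k) position in the thread
        // cluster. When reorder_thread_cluster is set (vector loads along dim M), m is the
        // fastest-varying coordinate so neighbouring threads touch adjacent M elements;
        // otherwise k varies fastest.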
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;

        index_t reducedTiles = 0;
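        // Editorial summary of the loop below: each pass loads one MThreadSliceSize x
        // KThreadSliceSize slice per thread, applies the element-wise pre-op in place, folds
        // the slice into the per-thread accumulators, then advances the source window by
        // K_BlockTileSize along K until all toReduceTiles tiles are consumed.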
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
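        // Editorial summary: cross-thread phase. For each of the MThreadSliceSize partial
        // results, every thread publishes its accumulator to LDS (using the same layout as the
        // cluster mapping above) and the partitioned blockwise reduction folds the
        // KThreadClusterSize partials sharing an m coordinate back into accu_value_buf.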
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            BlockwiseReduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });
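        // Editorial note: the epilogue below runs on the k == 0 lanes only and computes
        // out = alpha * reduce(in) when BetaIsZero, or out = alpha * reduce(in) + beta * out
        // otherwise; the prior destination values are read back only when beta is non-zero.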
        if(thread_k_cluster_id == 0)
        {
            if constexpr(!BetaIsZero)
            {
                if(!float_equal_zero{}(beta))
                {
                    StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         OutGridDesc_M,
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSize,
                                                         1,
                                                         false>(
                            out_grid_desc_m,
                            make_multi_index(block_global_1d_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m,
                                            out_global_buf,
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
                    });
                };
            };

            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
        }
    };
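    // Editorial note, inferred from the index bookkeeping below: RunWithIndex follows the same
    // blockwise scheme, but every value carries the K index it came from, so the kernel also
    // produces p_indices_global (argmin/argmax-style output).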
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const OutElementwiseOperation& acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_global,
                                        OutDataType beta,
                                        OutDataType* const __restrict__ p_out_global,
                                        const IndexDataType* const __restrict__ p_ws_indices_global,
                                        IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
                                                             AccDataType,
                                                             IndexDataType,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             reorder_thread_cluster,
                                                             ReduceOperation,
                                                             PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                             ReduceOperation,
                                                                             AccDataType,
                                                                             IndexDataType>;

        (void)p_ws_indices_global;

        // LDS
        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
        __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
        auto block_reduce_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_val_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_idx_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
            accu_index_buf;

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        index_t indexOffset = 0;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;

        index_t reducedTiles = 0;
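        // Editorial summary of the loop below: per tile, tag every loaded element with its
        // global K index (indexOffset plus the in-slice offset), reduce each thread slice to a
        // (value, index) pair, combine the pairs across the block through the two LDS buffers,
        // then fold the block result into the running answer with AccumulationWithIndex.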
        do
        {
            // load the thread slice
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_val_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // initialize the indices for the per-thread to-reduce values
                    in_thread_idx_buf(offset) =
                        indexOffset + thread_k_cluster_id * KThreadSliceSize + J();

                    // do element-wise pre-reduction operation
                    in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
                });

                AccDataType tmpValue   = zeroVal;
                IndexDataType tmpIndex = 0;

                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // reduce on the dim1 thread slice
                    AccumulationWithIndex::Calculate(
                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
                });

                // store thread local value to LDS for parallel reduction
                if constexpr(reorder_thread_cluster)
                {
                    block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpIndex;
                }
                else
                {
                    block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpIndex;
                }

                __syncthreads();

                BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                                                 block_reduce_idx_buf,
                                                 tmpValue,
                                                 tmpIndex,
                                                 thread_m_cluster_id,
                                                 thread_k_cluster_id);

                AccumulationWithIndex::Calculate(
                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            indexOffset += K_BlockTileSize;
            reducedTiles++;
        } while(reducedTiles < toReduceTiles);
        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                // for the indexed operation, acc_elementwise_op should do nothing
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if constexpr(!BetaIsZero)
            {
                if(!float_equal_zero{}(beta))
                {
                    StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         OutGridDesc_M,
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSize,
                                                         1,
                                                         false>(
                            out_grid_desc_m,
                            make_multi_index(block_global_1d_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m,
                                            out_global_val_buf,
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
                    });
                };
            };

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   false>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<index_t>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   false>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<index_t>{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);
            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    };
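    // Editorial note: RunSecondCallWithIndex is the final pass of the two-kernel variant. Its
    // inputs are the workspace values and indices written by a prior partial reduction, so the
    // indices are loaded from p_ws_indices_global rather than synthesized, and the
    // in_elementwise_op is unused (consistent with the (void) cast below).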
    __device__ static void
    RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                           const OutGridDesc_M& out_grid_desc_m,
                           const InElementwiseOperation in_elementwise_op,
                           const OutElementwiseOperation acc_elementwise_op,
                           AccDataType alpha,
                           const InDataType* const __restrict__ p_ws_values_global,
                           OutDataType beta,
                           OutDataType* const __restrict__ p_out_global,
                           const IndexDataType* const __restrict__ p_ws_indices_global,
                           IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
                                                             AccDataType,
                                                             IndexDataType,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             reorder_thread_cluster,
                                                             ReduceOperation,
                                                             PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                             ReduceOperation,
                                                                             AccDataType,
                                                                             IndexDataType>;

        (void)in_elementwise_op;

        // LDS
        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
        __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto src_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_ws_values_global,
                                                            in_grid_desc_m_k.GetElementSpaceSize(),
                                                            type_convert<InDataType>(zeroVal));
        const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
        auto block_reduce_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_val_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     IndexDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_idx_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
            accu_index_buf;

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
            IndexDataType,
            IndexDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;

        index_t reducedTiles = 0;
        do
        {
            // load the thread slice
            threadwise_src_val_load.Run(in_grid_desc_m_k,
                                        src_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_val_buf);
            threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                        src_global_idx_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_idx_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                AccDataType tmpValue   = zeroVal;
                IndexDataType tmpIndex = 0;

                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // reduce on the dim1 thread slice
                    AccumulationWithIndex::Calculate(
                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
                });

                // store thread local value to LDS for parallel reduction
                if constexpr(reorder_thread_cluster)
                {
                    block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpIndex;
                }
                else
                {
                    block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpIndex;
                }

                __syncthreads();

                BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                                                 block_reduce_idx_buf,
                                                 tmpValue,
                                                 tmpIndex,
                                                 thread_m_cluster_id,
                                                 thread_k_cluster_id);

                AccumulationWithIndex::Calculate(
                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
            });

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
            threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                // for the indexed operation, acc_elementwise_op should do nothing
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if constexpr(!BetaIsZero)
            {
                if(!float_equal_zero{}(beta))
                {
                    StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         OutGridDesc_M,
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSize,
                                                         1,
                                                         true>(
                            out_grid_desc_m,
                            make_multi_index(block_global_1d_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m,
                                            out_global_val_buf,
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
                    });
                };
            };

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<IndexDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<IndexDataType>{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);
            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    };
};

} // namespace ck
#endif
@@ -0,0 +1,268 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {
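// NOTE (editorial): in this variant each block reduces only its assigned slice of K and then
// atomic-adds the partial result into the destination, so no second kernel launch is needed;
// as the AtomicAdd store at the end implies, this is only suitable for reductions whose
// combining operation is addition.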
template <typename GridwiseReduction,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atomic_add(const InGridDesc_M_K in_grid_desc_m_k,
                                    const OutGridDesc_M out_grid_desc_m,
                                    const InElementwiseOperation in_elementwise_op,
                                    const AccElementwiseOperation acc_elementwise_op,
                                    index_t block_group_size,
                                    index_t num_k_block_tile_iteration,
                                    AccDataType alpha,
                                    const InDataType* const __restrict__ p_in_global,
                                    OutDataType* const __restrict__ p_out_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           out_grid_desc_m,
                           in_elementwise_op,
                           acc_elementwise_op,
                           block_group_size,
                           num_k_block_tile_iteration,
                           alpha,
                           p_in_global,
                           p_out_global);
};
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer_1d_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));

    using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
                                                                     AccDataType,
                                                                     BlockSize,
                                                                     MThreadClusterSize,
                                                                     KThreadClusterSize,
                                                                     reorder_thread_cluster,
                                                                     ReduceOperation,
                                                                     PropagateNan>;

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_global,
                               OutDataType* const __restrict__ p_out_global)
    {
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
        __shared__ AccDataType p_block_reduce_buffer[BlockSize];

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;
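        // Editorial note: consecutive groups of block_group_size blocks share one M tile;
        // blkgroup_id selects the tile, and block_local_id selects which K segment of length
        // reduceSizePerBlock (defined below) this block reduces.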
        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        // Each block executes multiple parallel reductions on the LDS and atomic-adds its
        // reduced output to the global location corresponding to each invariant dimension,
        // yielding a consistent reduced result for that invariant dimension. Due to the use
        // of vector_load, each block/thread is involved in multiple invariant dimensions.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            blockwise_reduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::AtomicAdd,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
        }
    };
};

} // namespace ck
#endif
@@ -0,0 +1,514 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {
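// NOTE (editorial): this header implements the two-call multiblock scheme. The kernel below
// only writes per-block partial results into an [M, block_group_size] workspace (values, plus
// indices when NeedIndices); a follow-up blockwise kernel, presumably
// kernel_reduce_blockwise_second_call from the first header, reduces the workspace to the
// final [M] output.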
template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename WorkspaceDesc_M_K,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                 const WorkspaceDesc_M_K workspace_desc_m_k,
                                 const InElementwiseOperation in_elementwise_op,
                                 const AccElementwiseOperation acc_elementwise_op,
                                 index_t block_group_size,
                                 index_t num_k_block_tile_iteration,
                                 const InDataType* const __restrict__ p_src_global,
                                 AccDataType* const __restrict__ p_ws_values_global,
                                 IndexDataType* const __restrict__ p_ws_indices_global)
{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               workspace_desc_m_k,
                               in_elementwise_op,
                               acc_elementwise_op,
                               block_group_size,
                               num_k_block_tile_iteration,
                               p_src_global,
                               p_ws_values_global,
                               p_ws_indices_global);
    }
    else
    {
        GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
                                        workspace_desc_m_k,
                                        in_elementwise_op,
                                        acc_elementwise_op,
                                        block_group_size,
                                        num_k_block_tile_iteration,
                                        p_src_global,
                                        p_ws_values_global,
                                        p_ws_indices_global);
    };
};
template <typename InDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename WorkspaceDesc_M_K,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer1dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const WorkspaceDesc_M_K& workspace_desc_m_k,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               const InDataType* const __restrict__ p_src_global,
                               AccDataType* const __restrict__ p_ws_values_global,
                               IndexDataType* const __restrict__ p_ws_indices_global)
    {
        using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
                                                                        AccDataType,
                                                                        BlockSize,
                                                                        MThreadClusterSize,
                                                                        KThreadClusterSize,
                                                                        reorder_thread_cluster,
                                                                        ReduceOperation,
                                                                        PropagateNan>;

        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        (void)p_ws_indices_global;
        (void)acc_elementwise_op;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
        __shared__ AccDataType p_block_reduce_buffer[BlockSize];

        const auto in_global_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
                                                            in_grid_desc_m_k.GetElementSpaceSize(),
                                                            type_convert<InDataType>(zeroVal));
        auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());

        auto block_reduce_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
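        // Editorial note: reduced_data_desc is [MThreadSliceSize, 1] because each block
        // contributes exactly one partially-reduced column; the k == 0 lanes store it at
        // column block_local_id of the [M, block_group_size] workspace below.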
        // Each block executes multiple parallel reductions on the LDS, and due to the use of
        // vector_load, each block/thread is involved in multiple invariant dimensions.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            BlockwiseReduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
        });

        if(thread_k_cluster_id == 0)
        {
            auto threadwise_workspace_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   AccDataType,
                                                   decltype(reduced_data_desc),
                                                   WorkspaceDesc_M_K,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize, 1>,
                                                   Sequence<0, 1>,
                                                   1,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    workspace_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     block_local_id),
                    PassThroughOp<AccDataType>{});

            threadwise_workspace_store.Run(reduced_data_desc,
                                           make_tuple(I0, I0),
                                           accu_value_buf,
                                           workspace_desc_m_k,
                                           workspace_global_buf);
        }
    };
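    // Editorial note: RunWithIndex mirrors the partial-reduction flow above, but tags each
    // element with its global K index (see indexOffset below) so the workspace also records
    // where each partial extremum came from, ready for the index-tracking second pass.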
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const WorkspaceDesc_M_K& workspace_desc_m_k,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const InDataType* const __restrict__ p_src_global,
|
||||
AccDataType* const __restrict__ p_ws_values_global,
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
|
||||
__shared__ index_t p_block_reduce_idx_buffer[BlockSize];
|
||||
|
||||
const auto in_global_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto block_reduce_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
|
||||
auto block_reduce_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_val_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
IndexDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
|
||||
accu_index_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
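        // Map the linear thread id onto the (M, K) thread cluster. With reorder_thread_cluster
        // the M dimension varies fastest (m = tid % MThreadClusterSize); otherwise K does
        // (k = tid % KThreadClusterSize). E.g. with MThreadClusterSize = 4, tid = 5 reorders to
        // (m, k) = (1, 1).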
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t indexOffset = block_local_id * reduceSizePerBlock;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        index_t reducedTiles = 0;
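
        // Per-iteration flow: load one MThreadSliceSize x KThreadSliceSize tile, tag each
        // element with its global K index, reduce the K slice within the thread, then reduce
        // across the block through LDS and fold the block result into accu_value/index_buf.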
        do
        {
            // load the thread slice
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_val_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // initialize the indices for the per-thread to-reduce values
                    in_thread_idx_buf(offset) =
                        indexOffset + thread_k_cluster_id * KThreadSliceSize + J();

                    // do element-wise pre-reduction operation
                    in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
                });

                AccDataType tmpValue = zeroVal;
                IndexDataType tmpIndex = 0;

                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // reduce on the dim1 thread slice
                    AccumulationWithIndex::Calculate(
                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
                });

                // store thread local value to LDS for parallel reduction
                if constexpr(reorder_thread_cluster)
                {
                    block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpIndex;
                }
                else
                {
                    block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpIndex;
                }

                __syncthreads();

                BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                                                 block_reduce_idx_buf,
                                                 tmpValue,
                                                 tmpIndex,
                                                 thread_m_cluster_id,
                                                 thread_k_cluster_id);

                AccumulationWithIndex::Calculate(
                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            indexOffset += K_BlockTileSize;

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
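
        // Only the thread with k-cluster id 0 in each M row writes out; the value and index
        // stores below target the same (M, block_local_id) workspace coordinates.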
        if(thread_k_cluster_id == 0)
        {
            auto threadwise_workspace_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   AccDataType,
                                                   decltype(reduced_data_desc),
                                                   WorkspaceDesc_M_K,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize, 1>,
                                                   Sequence<0, 1>,
                                                   1,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    workspace_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     block_local_id),
                    PassThroughOp<AccDataType>{});

            auto threadwise_workspace_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   WorkspaceDesc_M_K,
                                                   PassThroughOp<IndexDataType>,
                                                   Sequence<MThreadSliceSize, 1>,
                                                   Sequence<0, 1>,
                                                   1,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    workspace_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     block_local_id),
                    PassThroughOp<IndexDataType>{});

            threadwise_workspace_val_store.Run(reduced_data_desc,
                                               make_tuple(I0, I0),
                                               accu_value_buf,
                                               workspace_desc_m_k,
                                               workspace_global_val_buf);
            threadwise_workspace_idx_store.Run(reduced_data_desc,
                                               make_tuple(I0, I0),
                                               accu_index_buf,
                                               workspace_desc_m_k,
                                               workspace_global_idx_buf);
        }
    };
};

} // namespace ck
#endif
@@ -0,0 +1,435 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP

#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_global,
                                         OutDataType beta,
                                         OutDataType* const __restrict__ p_out_global,
                                         IndexDataType* const __restrict__ p_indices_global)
{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_global,
                               beta,
                               p_out_global,
                               p_indices_global);
    }
    else
    {
        GridwiseReduction::RunWithIndices(in_grid_desc_m_k,
                                          out_grid_desc_m,
                                          in_elementwise_op,
                                          acc_elementwise_op,
                                          alpha,
                                          p_in_global,
                                          beta,
                                          p_out_global,
                                          p_indices_global);
    };
};
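
// Note: in this threadwise variant each thread reduces its own MThreadSliceSize rows over the
// whole K length, so no LDS buffer or __syncthreads() is needed; a block is only a packaging
// of independent threads.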
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          bool BetaIsZero,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_global,
                               OutDataType beta,
                               OutDataType* const __restrict__ p_out_global,
                               IndexDataType* const __restrict__ p_indices_global)
    {
        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        (void)p_indices_global;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t reducedLength = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

            accu_value_buf(I) *= alpha;
        });

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
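
        // Blend with the prior output when beta != 0: out = alpha * reduce(in) + beta * out.
        // The path below reloads the destination values and accumulates them before the store.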
        if constexpr(!BetaIsZero)
        {
            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     Sequence<MThreadSliceSize>,
                                                     Sequence<0>,
                                                     0,
                                                     1,
                                                     1,
                                                     true>(
                        out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValue_buf;

                threadwise_dst_load.Run(out_grid_desc_m,
                                        dst_global_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValue_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
                });
            };
        };

        auto threadwise_dst_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp<AccDataType>,
                                               Sequence<MThreadSliceSize>,
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum_t::Set,
                                               1,
                                               false>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp<AccDataType>{});

        threadwise_dst_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
    };

    __device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k,
                                          const OutGridDesc_M& out_grid_desc_m,
                                          const InElementwiseOperation& in_elementwise_op,
                                          const AccElementwiseOperation& acc_elementwise_op,
                                          AccDataType alpha,
                                          const InDataType* const __restrict__ p_in_global,
                                          OutDataType beta,
                                          OutDataType* const __restrict__ p_out_global,
                                          IndexDataType* const __restrict__ p_indices_global)
    {
        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                             ReduceOperation,
                                                                             AccDataType,
                                                                             IndexDataType>;
        (void)acc_elementwise_op;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
            accu_index_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t indexStart = 0;
        index_t reducedLength = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    AccumulationWithIndex::Calculate(accu_value_buf(I),
                                                     in_thread_buf[offset],
                                                     accu_index_buf(I),
                                                     indexStart + J);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            indexStart += KThreadSliceSize;
            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        // for indexed operation, acc_elementwise_op should do nothing
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

            accu_value_buf(I) *= alpha;
        });

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        if constexpr(!BetaIsZero)
        {
            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     Sequence<MThreadSliceSize>,
                                                     Sequence<0>,
                                                     0,
                                                     1,
                                                     1,
                                                     false>(
                        out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValue_buf;

                threadwise_dst_load.Run(out_grid_desc_m,
                                        out_global_val_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValue_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
                });
            };
        };

        auto threadwise_dst_val_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp<AccDataType>,
                                               Sequence<MThreadSliceSize>,
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum_t::Set,
                                               1,
                                               false>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp<AccDataType>{});

        auto threadwise_dst_idx_store =
            ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                               IndexDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp<IndexDataType>,
                                               Sequence<MThreadSliceSize>,
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum_t::Set,
                                               1,
                                               false>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp<IndexDataType>{});

        threadwise_dst_val_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);

        threadwise_dst_idx_store.Run(
            reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf);
    };
};

} // namespace ck
#endif
@@ -0,0 +1,649 @@
#ifndef CK_GRIDWISE_BATCHED_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_BATCHED_GEMM_XDLOPS_V2R3_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseBatchedGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_G_K0_M_K1,
          typename BGridDesc_G_K0_N_K1,
          typename CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_xdlops_v2r3(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AGridDesc_G_K0_M_K1 a_grid_desc_g_k0_m_k1,
            const BGridDesc_G_K0_N_K1 b_grid_desc_g_k0_n_k1,
            const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseBatchedGemm::GetSharedMemoryNumberOfByte()];

    GridwiseBatchedGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                         p_b_grid,
                                                         p_c_grid,
                                                         p_shared,
                                                         a_grid_desc_g_k0_m_k1,
                                                         b_grid_desc_g_k0_n_k1,
                                                         c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                                                         a_element_op,
                                                         b_element_op,
                                                         c_element_op,
                                                         block_2_ctile_map);
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_G_K0_M_K1,
          typename BGridDesc_G_K0_N_K1,
          typename CGridDesc_G_M_N,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t K1Value,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_G_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_G_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseBatchedGemm_gk0mk1_gk0nk1_gmn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr auto
    GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_g_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                               Number<MPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_g_k0_m_k1;
    }

    __host__ __device__ static constexpr auto
    GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_g_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                               Number<NPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_g_k0_n_k1;
    }
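
    // The G x K0 x M/N x K1 LDS descriptors above carry a unit batch dimension; the two
    // helpers below freeze that dimension out so the blockwise GEMM sees plain
    // K0 x M(/N) x K1 tiles.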
    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto a_block_desc_g_k0_m_k1 =
            GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();

        constexpr auto K0 = a_block_desc_g_k0_m_k1.GetLength(I1);
        constexpr auto M  = a_block_desc_g_k0_m_k1.GetLength(I2);

        constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
            a_block_desc_g_k0_m_k1,
            make_tuple(make_freeze_transform(I0),
                       make_pass_through_transform(K0),
                       make_pass_through_transform(M),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto b_block_desc_g_k0_n_k1 =
            GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();

        constexpr auto K0 = b_block_desc_g_k0_n_k1.GetLength(I1);
        constexpr auto N  = b_block_desc_g_k0_n_k1.GetLength(I2);

        constexpr auto b_block_desc_k0_n_k1 = transform_tensor_descriptor(
            b_block_desc_g_k0_n_k1,
            make_tuple(make_freeze_transform(I0),
                       make_pass_through_transform(K0),
                       make_pass_through_transform(N),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return b_block_desc_k0_n_k1;
    }
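
    // A and B tiles are laid out back-to-back in LDS, each padded up to a multiple of K1
    // elements, so the byte count is simply (a_size + b_size) * sizeof(FloatAB).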
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_g_k0_m_k1 =
            GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();

        constexpr auto b_block_desc_g_k0_n_k1 =
            GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_g_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_g_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_G_K0_M_K1& a_grid_desc_g_k0_m_k1,
                  const BGridDesc_G_K0_N_K1& b_grid_desc_g_k0_n_k1,
                  const CGridDesc_G_M_N& c_grid_desc_g_m_n,
                  index_t M01,
                  index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                      "Invalid tuning param!");

        // const auto G = a_grid_desc_g_k0_m_k1.GetLength(I0);
        const auto K0 = a_grid_desc_g_k0_m_k1.GetLength(I1);
        const auto M  = a_grid_desc_g_k0_m_k1.GetLength(I2);
        const auto N  = b_grid_desc_g_k0_n_k1.GetLength(I2);

        if(!(M == c_grid_desc_g_m_n.GetLength(I1) && N == c_grid_desc_g_m_n.GetLength(I2) &&
             K0 == b_grid_desc_g_k0_n_k1.GetLength(I1) &&
             K1 == a_grid_desc_g_k0_m_k1.GetLength(I3) &&
             K1 == b_grid_desc_g_k0_n_k1.GetLength(I3)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const index_t grid_size = G * (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;

        return has_main_k0_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                K1>;

        return BlockwiseGemm::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_g_m_n);
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
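    // The default map chains two adaptors: a merge splits block_id into
    // (g, m00, n00, m01, n01), which an unmerge pair then recombines into
    // (g, m0 = m00 * M01 + m01, n0 = n00 * N01 + n01); with M01 = N01 = 1 this is a plain
    // row-major walk over (G, M0, N0).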
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_G_M_N& c_grid_desc_g_m_n, index_t M01, index_t N01)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(G),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(G, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_g_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_g_m0_n0_block_cluster_adaptor;
    }

    using CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_G_M_N{}));
    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_G_M_N{}, 1, 1));

    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        void* __restrict__ p_shared,
        const AGridDesc_G_K0_M_K1& a_grid_desc_g_k0_m_k1,
        const BGridDesc_G_K0_N_K1& b_grid_desc_g_k0_n_k1,
        const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CElementwiseOperation& c_element_op,
        const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_g_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_g_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        const auto K0 = a_grid_desc_g_k0_m_k1.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
        const index_t g_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_g_k0_m_k1 =
            GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_g_k0_n_k1 =
            GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, MPerBlock, K1>,
                                              ABlockTransferThreadClusterLengths_G_K0_M_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(a_grid_desc_g_k0_m_k1),
                                              decltype(a_block_desc_g_k0_m_k1),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              ABlockTransferSrcVectorDim,
                                              3,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                a_grid_desc_g_k0_m_k1,
                make_multi_index(g_idx_on_grid, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_g_k0_m_k1,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_G_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_grid_desc_g_k0_n_k1),
                                              decltype(b_block_desc_g_k0_n_k1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              BBlockTransferSrcVectorDim,
                                              3,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                b_grid_desc_g_k0_n_k1,
                make_multi_index(g_idx_on_grid, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_g_k0_n_k1,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check

        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_g_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_g_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_g_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_g_k0_m_k1, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc_g_k0_n_k1, b_grid_buf);

            a_blockwise_copy.RunWrite(a_block_desc_g_k0_m_k1, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc_g_k0_n_k1, b_block_buf);
        }

        // Initialize C
        c_thread_buf.Clear();

        // main body
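        // Pipelined K loop: while blockwise_gemm consumes the tile already in LDS, the
        // blockwise copies fetch the next K0PerBlock slice from global memory into registers;
        // the two block_sync_lds() calls keep the LDS reads and the subsequent overwrite apart.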
        if constexpr(HasMainKBlockLoop)
        {
            index_t k0_block_data_begin = 0;

            do
            {
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_g_k0_m_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_g_k0_n_k1, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_grid_desc_g_k0_m_k1, a_grid_buf);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_grid_desc_g_k0_n_k1, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc_g_k0_m_k1, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc_g_k0_n_k1, b_block_buf);

                k0_block_data_begin += K0PerBlock;
            } while(k0_block_data_begin < (K0 - K0PerBlock));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

        // output: register to global memory
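        // The flat m/n thread origins computed below are decomposed through merge adaptors into
        // the (M0, M1, M2, M3, M4) and (N0, N1, N2) coordinates used by the xdlops C descriptor.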
        {
            constexpr auto c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2();

            // constexpr auto G = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
            constexpr auto M0 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
            constexpr auto N0 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
            constexpr auto M1 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
            constexpr auto N1 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
            constexpr auto M2 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
            constexpr auto M3 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
            constexpr auto M4 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
            constexpr auto N2 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I8);

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2),
                CElementwiseOperation,
                Sequence<I1, M0, N0, I1, I1, M2, I1, M4, I1>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(g_idx_on_grid,
                                       m_thread_data_on_grid_idx[I0],
                                       n_thread_data_on_grid_idx[I0],
                                       m_thread_data_on_grid_idx[I1],
                                       n_thread_data_on_grid_idx[I1],
                                       m_thread_data_on_grid_idx[I2],
                                       m_thread_data_on_grid_idx[I3],
                                       m_thread_data_on_grid_idx[I4],
                                       n_thread_data_on_grid_idx[I2]),
                      c_element_op};

            c_thread_copy.Run(c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,659 @@
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"

namespace ck {

template <typename GridwiseContraction,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
          typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
          typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
          typename CGridBlockCluster_BlockId_To_GM10_GN10,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_dlops_v1r2(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
            const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
            const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
            const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
{
    constexpr index_t shared_block_size =
        GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseContraction::Run(p_a_grid,
                             p_b_grid,
                             p_c_grid,
                             p_shared_block,
                             a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                             b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                             c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                             c_grid_block_cluster_blockid_to_gm10_gn10,
                             integral_constant<bool, HasMainKBlockLoop>{},
                             integral_constant<bool, HasDoubleTailKBlockLoop>{});
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_GK0_GM0_GM1_GK1,
          typename BGridDesc_GK0_GN0_GN1_GK1,
          typename CGridDesc_GM0_GM1_GN0_GN1,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t GK0PerBlock,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
          typename BM10BN10ThreadClusterBM10Xs,
          typename BM10BN10ThreadClusterBN10Xs,
          typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // GM0 and GN0 need to be known at compile-time
    static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
    static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
    static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
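
    // Note: the shared-memory budget below reserves two copies of the A and B tiles,
    // presumably so the K loop can ping-pong between LDS buffers (cf. HasMainKBlockLoop and
    // HasDoubleTailKBlockLoop in the kernel above).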
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // lds max alignment
        // TODO: part of them should be moved into blockwise-gemm
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = GK1;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
            max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
            max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
                  const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
                  const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
                          is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
                      "wrong! GM0 and GN0 need to be known at compile-time");

        const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
        const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
        const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return (
            (GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
             GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
             GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
             GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
             GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
             GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
             GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
             GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
             GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
             GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
            (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
    {
        const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
        const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);

        constexpr index_t GM11 = GM1PerBlockGM11;
        constexpr index_t GN11 = GN1PerBlockGN11;

        const index_t GM10 = GM1 / GM11;
        const index_t GN10 = GN1 / GN11;

        const index_t grid_size = GM10 * GN10;

        return grid_size;
    }
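
    // The main loop consumes 2 * GK0PerBlock per iteration (one per LDS buffer); e.g. for
    // GK0 = 4 * GK0PerBlock the quotient (4 + 1) / 2 = 2 > 1, so a main loop exists, while
    // GK0 = 2 * GK0PerBlock gives 3 / 2 = 1 and only the tail iterations run.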
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
    {
        const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
    {
        const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
        const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
    {
        const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
        const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);

        const auto GM11 = Number<GM1PerBlockGM11>{};
        const auto GM10 = GM1 / GM11;

        const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
            a_grid_desc_gk0_gm0_gm1_gk1,
            make_tuple(make_pass_through_transform(GK0),
                       make_pass_through_transform(GM0),
                       make_unmerge_transform(make_tuple(GM10, GM11)),
                       make_pass_through_transform(GK1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
    }

    __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
        const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
    {
        const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
        const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);

        const auto GN11 = Number<GN1PerBlockGN11>{};
        const auto GN10 = GN1 / GN11;

        const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
            b_grid_desc_gk0_gn0_gn1_gk1,
            make_tuple(make_pass_through_transform(GK0),
                       make_pass_through_transform(GN0),
                       make_unmerge_transform(make_tuple(GN10, GN11)),
                       make_pass_through_transform(GK1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
    }
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
constexpr auto BM = GM0 * GM11;
|
||||
constexpr auto BN = GN0 * GN11;
|
||||
|
||||
constexpr auto BM1 =
|
||||
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
|
||||
BM1PerThreadBM11>{};
|
||||
constexpr auto BN1 =
|
||||
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
|
||||
BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto BM0 = BM / BM1;
|
||||
constexpr auto BN0 = BN / BN1;
|
||||
|
||||
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
make_tuple(make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
|
||||
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_merge_transform(make_tuple(GM0, GM11)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_merge_transform(make_tuple(GN0, GN11))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
|
||||
c_gm10_bm_gn10_bn_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_unmerge_transform(make_tuple(BM0, BM1)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_unmerge_transform(make_tuple(BN0, BN1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
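
        // Map the flat 1-d block id onto the (GM10, GN10) tile grid; with this
        // merge the mapping is effectively block_id = gm10 * GN10 + gn10.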
        const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        return c_grid_block_cluster_blockid_to_gm10_gn10;
    }

    using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
        decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
    using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
        decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
    using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
        decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
    using CGridBlockCluster_BlockId_To_GM10_GN10 =
        decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
        const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
        const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
        const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());

        const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);

        // divide block work by [GM10, GN10]
        const auto c_gm10_gn10_block_cluster_idx =
            c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
                make_multi_index(get_block_1d_id()));

        // HACK: this forces index data into SGPR
        const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
        const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);

        // lds max alignment
        // TODO: part of this should be moved into blockwise-gemm
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = GK1;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
            max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
            max_lds_align);

        // A matrix in LDS memory for blockwise GEMM
        // be careful of LDS alignment
        constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);

        // B matrix in LDS memory for blockwise GEMM
        // be careful of LDS alignment
        constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);

        static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
                              a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
                          b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
                              b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
                      "wrong!");

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
            decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3, 4>,
            ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
            false,
            true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                  make_multi_index(0, 0, igm10, 0, 0),
                  a_block_desc_gk0_gm0_gm10_gm11_gk1,
                  make_multi_index(0, 0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
            BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
            decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3, 4>,
            BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
            BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
            BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
            false,
            true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                  make_multi_index(0, 0, ign10, 0, 0),
                  b_block_desc_gk0_gn0_gn10_gn11_gk1,
                  make_multi_index(0, 0, 0, 0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
        // b_mtx[GK0PerBlock, GN1PerBlockGN11] is in LDS
        // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
        // register
        const auto blockwise_gemm =
            BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
                BlockSize,
                FloatAB,
                FloatAB,
                FloatAcc,
                decltype(a_block_desc_gk0_bm_gk1),
                decltype(b_block_desc_gk0_bn_gk1),
                BM1PerThreadBM11,
                BN1PerThreadBN11,
                BK0PerThread,
                BM10BN10ThreadClusterBM10Xs,
                BM10BN10ThreadClusterBN10Xs,
                BM1PerThreadBM11,
                BN1PerThreadBN11>{};

        constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
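
        // LDS layout of the double buffer: [A even | A odd | B even | B odd], which
        // is why the B base pointer is offset by 2 * a_block_aligned_space_size.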

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
            c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                    decltype(c_thread_desc_bm0_bm1_bn0_bn1),
                                    decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
            .Run(c_thread_desc_bm0_bm1_bn0_bn1,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(
                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
            b_blockwise_copy.RunRead(
                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

            a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            index_t gk0_block_on_grid = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                    a_block_slice_copy_step,
                                                    AGridMoveSliceWindowStepHacks{});
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                    b_block_slice_copy_step,
                                                    BGridMoveSliceWindowStepHacks{});

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
                b_blockwise_copy.RunRead(
                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
                                   a_block_even_buf,
                                   b_block_even_buf,
                                   c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                    a_block_slice_copy_step,
                                                    AGridMoveSliceWindowStepHacks{});
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                    b_block_slice_copy_step,
                                                    BGridMoveSliceWindowStepHacks{});

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
                b_blockwise_copy.RunRead(
                    b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);

                gk0_block_on_grid += 2 * GK0PerBlock;
            } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
        }
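
        // At this point either one or two GK0PerBlock windows remain (see
        // CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop); the tail
        // below issues at most one more prefetch and then drains both LDS buffers.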
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                                a_block_slice_copy_step,
                                                AGridMoveSliceWindowStepHacks{});
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                                b_block_slice_copy_step,
                                                BGridMoveSliceWindowStepHacks{});

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(
                a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
            b_blockwise_copy.RunRead(
                b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
                c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if there is 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
                               Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
                               I1,
                               Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
                               Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));

            const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
                decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
                Sequence<1,
                         c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
                         c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
                         1,
                         c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
                         c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                       make_multi_index(igm10,
                                        c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
                                        c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
                                        ign10,
                                        c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
                                        c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
                .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                     c_grid_buf,
                     CGridStepHacks{});
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,605 @@
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AKM0M1GridDesc,
          typename BKN0N1GridDesc,
          typename CM0M10M11N0N10N11GridDesc,
          typename CBlockIdToM0N0BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v1r2(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AKM0M1GridDesc a_k_m0_m1_grid_desc,
            const BKN0N1GridDesc b_k_n0_n1_grid_desc,
            const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
            const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k_m0_m1_grid_desc,
                      b_k_n0_n1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      cblockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
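
// A minimal host-side launch sketch (illustration only, not part of this header;
// hipLaunchKernelGGL is the standard HIP launch macro, and the pointers, descriptors
// and remaining template arguments are assumed to be prepared by the caller):
//
//   const auto a_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
//   const auto b_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
//   const auto c_desc = GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
//   const auto c_map  = GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
//
//   hipLaunchKernelGGL((kernel_gemm_dlops_v1r2<GridwiseGemm, FloatAB, FloatC, ...>),
//                      dim3(GridwiseGemm::CalculateGridSize(M, N)), dim3(BlockSize),
//                      0, nullptr, p_a, p_b, p_c, a_desc, b_desc, c_desc, c_map);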

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AKMGridDesc,
          typename BKNGridDesc,
          typename CMNGridDesc,
          index_t MPerBlockM1,
          index_t NPerBlockN1,
          index_t KPerBlock,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          index_t M11N11ThreadClusterM1100,
          index_t M11N11ThreadClusterN1100,
          index_t M11N11ThreadClusterM1101,
          index_t M11N11ThreadClusterN1101,
          typename ABlockTransferThreadSliceLengths_K_M0_M1,
          typename ABlockTransferThreadClusterLengths_K_M0_M1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N0_N1,
          typename BBlockTransferThreadClusterLengths_K_N0_N1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
                                                 Number<BBlockTransferDstScalarPerVector_N1>{},
                                                 Number<M1PerThreadM111>{},
                                                 Number<N1PerThreadN111>{});

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
                                                            const BKNGridDesc& b_k_n_grid_desc,
                                                            const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = a_k_m_grid_desc.GetLength(I1);
        const auto N = b_k_n_grid_desc.GetLength(I1);
        const auto K = a_k_m_grid_desc.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
                K == b_k_n_grid_desc.GetLength(I0)) &&
               (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
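        // i.e. one workgroup per (MPerBlockM1 x NPerBlockN1) tile of C; for example,
        // M = N = 1024 with 128 x 128 tiles gives an 8 x 8 = 64-block grid.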

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
    {
        const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
    {
        const auto K = a_k_m_grid_desc.GetLength(I0);
        const auto M = a_k_m_grid_desc.GetLength(I1);

        const auto M1 = Number<MPerBlockM1>{};
        const auto M0 = M / M1;

        const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
            a_k_m_grid_desc,
            make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return a_k_m0_m1_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
    {
        const auto K = b_k_n_grid_desc.GetLength(I0);
        const auto N = b_k_n_grid_desc.GetLength(I1);

        const auto N1 = Number<NPerBlockN1>{};
        const auto N0 = N / N1;

        const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
            b_k_n_grid_desc,
            make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return b_k_n0_n1_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlockM1>{};
        constexpr auto N1 = Number<NPerBlockN1>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 =
            Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
        constexpr auto N11 =
            Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;
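
        // Three-level split of each output dimension, e.g. M = M0 * M10 * M11:
        // M0 counts block tiles, M10 counts thread-cluster repeats within a tile,
        // and M11 is the per-cluster footprint (cluster size x per-thread work).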
        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_m0_m10_m11_n0_n10_n11_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlockM1>{};
        constexpr auto N1 = Number<NPerBlockN1>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
                                             make_tuple(Sequence<0, 1>{}),
                                             make_tuple(Sequence<0>{}));

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
    using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
    using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
        const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
        const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
        const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

        const auto K = a_k_m0_m1_grid_desc.GetLength(I0);

        // divide block work by [M, N]
        const auto c_m0_n0_block_cluster_idx =
            cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(get_block_1d_id()));

        // HACK: this forces index data into SGPR
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
                                                 Number<BBlockTransferDstScalarPerVector_N1>{},
                                                 Number<M1PerThreadM111>{},
                                                 Number<N1PerThreadN111>{});

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
                                            Sequence<KPerBlock, 1, MPerBlockM1>,
                                            ABlockTransferThreadSliceLengths_K_M0_M1,
                                            ABlockTransferThreadClusterLengths_K_M0_M1,
                                            ABlockTransferThreadClusterArrangeOrder,
                                            FloatAB,
                                            FloatAB,
                                            decltype(a_k_m0_m1_grid_desc),
                                            decltype(a_k_m0_m1_block_desc),
                                            ABlockTransferSrcAccessOrder,
                                            Sequence<0, 1, 2>,
                                            ABlockTransferSrcVectorDim,
                                            2,
                                            ABlockTransferSrcScalarPerVector,
                                            ABlockTransferDstScalarPerVector_M1,
                                            1,
                                            1,
                                            AThreadTransferSrcResetCoordinateAfterRun,
                                            true>(a_k_m0_m1_grid_desc,
                                                  make_multi_index(0, im0, 0),
                                                  a_k_m0_m1_block_desc,
                                                  make_multi_index(0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
                                            Sequence<KPerBlock, 1, NPerBlockN1>,
                                            BBlockTransferThreadSliceLengths_K_N0_N1,
                                            BBlockTransferThreadClusterLengths_K_N0_N1,
                                            BBlockTransferThreadClusterArrangeOrder,
                                            FloatAB,
                                            FloatAB,
                                            decltype(b_k_n0_n1_grid_desc),
                                            decltype(b_k_n0_n1_block_desc),
                                            BBlockTransferSrcAccessOrder,
                                            Sequence<0, 1, 2>,
                                            BBlockTransferSrcVectorDim,
                                            2,
                                            BBlockTransferSrcScalarPerVector,
                                            BBlockTransferDstScalarPerVector_N1,
                                            1,
                                            1,
                                            BThreadTransferSrcResetCoordinateAfterRun,
                                            true>(b_k_n0_n1_grid_desc,
                                                  make_multi_index(0, in0, 0),
                                                  b_k_n0_n1_block_desc,
                                                  make_multi_index(0, 0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[KPerBlock, MPerBlockM1] is in LDS
        // b_mtx[KPerBlock, NPerBlockN1] is in LDS
        // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
        // register
        const auto blockwise_gemm =
            BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
                                                                FloatAB,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k_m_block_desc),
                                                                decltype(b_k_n_block_desc),
                                                                M1PerThreadM111,
                                                                N1PerThreadN111,
                                                                KPerThread,
                                                                M11N11ThreadClusterM1100,
                                                                M11N11ThreadClusterN1100,
                                                                M11N11ThreadClusterM1101,
                                                                M11N11ThreadClusterN1101,
                                                                M1PerThreadM111,
                                                                N1PerThreadN111>{};
        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

        constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size =
            math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                    decltype(c_m10_m11_n10_n11_thread_desc),
                                    decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
            .Run(c_m10_m11_n10_n11_thread_desc,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
        constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
        constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};

        // hack to control index calculation when moving the slice window of A and B matrix for
        // threadwise copy
        constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
            AGridMoveSliceWindowStepHacks{};
        constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
            BGridMoveSliceWindowStepHacks{};

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_k_m0_m1_block_desc.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_k_n0_n1_block_desc.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(
                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
            b_blockwise_copy.RunRead(
                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

            a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m0_m1_global_move_slice_window_step_hack);
                b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
                                                    b_block_slice_copy_step,
                                                    b_k_n0_n1_global_move_slice_window_step_hack);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
                b_blockwise_copy.RunRead(
                    b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
                                   a_block_even_buf,
                                   b_block_even_buf,
                                   c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m0_m1_global_move_slice_window_step_hack);
                b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
                                                    b_block_slice_copy_step,
                                                    b_k_n0_n1_global_move_slice_window_step_hack);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
                b_blockwise_copy.RunRead(
                    b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);

                k_block_data_begin += 2 * KPerBlock;
            } while(k_block_data_begin < K - 2 * KPerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                a_block_slice_copy_step,
                                                a_k_m0_m1_global_move_slice_window_step_hack);
            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
                                                b_block_slice_copy_step,
                                                b_k_n0_n1_global_move_slice_window_step_hack);

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(
                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
            b_blockwise_copy.RunRead(
                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
                c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if there is 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                               I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
                decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
                      make_multi_index(im0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                       in0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
                .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_m0_m10_m11_n0_n10_n11_grid_desc,
                     c_grid_buf,
                     CGridStepHacks{});
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,593 @@
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
#define CK_GRIDWISE_GEMM_V1R3_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_set.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AK0M0M1K1GridDesc,
          typename BK0N0N1K1GridDesc,
          typename CM0M10M11N0N10N11GridDesc,
          typename CBlockIdToM0N0BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v1r3(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
            const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
            const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
            const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];
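    // Statically sized LDS scratch; GetSharedMemoryNumberOfByte() already accounts
    // for double buffering of both the A and B block tiles.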

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m0_m1_k1_grid_desc,
                      b_k0_n0_n1_k1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      cblockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          index_t MPerBlockM1,
          index_t NPerBlockN1,
          index_t KPerBlock,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
    static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
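    // The K dimension arrives pre-split as K = K0 * K1; K1 is the innermost,
    // compile-time dimension and is also used as the LDS alignment below.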

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

        // TODO: check alignment
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr bool
    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                  const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
               (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
    {
        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);

        const auto M1 = Number<MPerBlockM1>{};
        const auto M0 = M / M1;

        const auto a_k0_m0_m1_k1_grid_desc =
            transform_tensor_descriptor(a_k0_m_k1_grid_desc,
                                        make_tuple(make_pass_through_transform(K0),
                                                   make_unmerge_transform(make_tuple(M0, M1)),
                                                   make_pass_through_transform(K1)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_k0_m0_m1_k1_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
    {
        const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);

        const auto N1 = Number<NPerBlockN1>{};
        const auto N0 = N / N1;

        const auto b_k0_n0_n1_k1_grid_desc =
            transform_tensor_descriptor(b_k0_n_k1_grid_desc,
                                        make_tuple(make_pass_through_transform(K0),
                                                   make_unmerge_transform(make_tuple(N0, N1)),
                                                   make_pass_through_transform(K1)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_k0_n0_n1_k1_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlockM1>{};
        constexpr auto N1 = Number<NPerBlockN1>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 =
            Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
                   M1PerThreadM111>{};
        constexpr auto N11 =
            Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
                   N1PerThreadN111>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_m0_m10_m11_n0_n10_n11_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlockM1>{};
        constexpr auto N1 = Number<NPerBlockN1>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
                                             make_tuple(Sequence<0, 1>{}),
                                             make_tuple(Sequence<0>{}));

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
    using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
    using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
        const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
        const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
        const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto c_m0_n0_block_cluster_idx =
            cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(get_block_1d_id()));

        // HACK: this forces index data into SGPR
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
|
||||
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
|
||||
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
|
||||
"wrong!");
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_grid_desc),
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(a_k0_m0_m1_k1_grid_desc,
|
||||
make_multi_index(0, im0, 0, 0),
|
||||
a_k0_m0_m1_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n0_n1_k1_grid_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(b_k0_n0_n1_k1_grid_desc,
|
||||
make_multi_index(0, in0, 0, 0),
|
||||
b_k0_n0_n1_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlockN1] is in LDS
|
||||
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
|
||||
// register
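// i.e. for each k: c[m][n] += a[k][m] * b[k][n]; A is stored K-major in LDS,
// hence the transpose in the formula above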
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
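// LDS layout in FloatAB elements, writing S_a = a_block_aligned_space_size
// and S_b = b_block_aligned_space_size:
//   [0, S_a)                  A ping (even) buffer
//   [S_a, 2*S_a)              A pong (odd) buffer
//   [2*S_a, 2*S_a + S_b)      B ping (even) buffer
//   [2*S_a + S_b, 2*S_a + 2*S_b)  B pong (odd) buffer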
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
|
||||
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < K0 - 2 * KPerBlock);
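// Each trip through the loop body consumes 2 * KPerBlock of the K0 dimension
// (one even and one odd LDS buffer), so the loop exits with either one or two
// K tiles remaining; HasDoubleTailKBlockLoop below selects the matching tail.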
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
||||
I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
|
||||
|
||||
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
|
||||
Sequence<1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
|
||||
1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
make_multi_index(im0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
|
||||
in0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif

include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp (new file, 458 lines)
@@ -0,0 +1,458 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_V2_HPP
|
||||
#define CK_GRIDWISE_GEMM_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "blockwise_gemm_dlops_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
index_t KPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalStepHacks,
|
||||
typename BGlobalStepHacks,
|
||||
typename CGlobalStepHacks,
|
||||
typename AGlobalMoveSliceWindowStepHacks,
|
||||
typename BGlobalMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto E = EPerBlock * 3 * 3;
|
||||
|
||||
constexpr auto max_lds_align =
|
||||
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return a_block_space_size * sizeof(FloatAB);
|
||||
}
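// Illustrative sizing (not a tuned configuration): with EPerBlock = 8,
// KPerBlock = 16, ABlockTransferDstScalarPerVector_K = 4 and FloatAB = float,
// E = 8 * 3 * 3 = 72 (the 3 * 3 factor suggests a 3x3 filter window), the
// aligned A tile holds 72 * 16 = 1152 elements, so this function returns
// 1152 * 4 = 4608 bytes of LDS.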
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc& b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc& c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
|
||||
|
||||
constexpr auto E = EPerBlock * 3 * 3;
|
||||
|
||||
// const auto E = a_e_k_global_desc.GetLength(I0);
|
||||
const auto K = a_e_k_global_desc.GetLength(I1);
|
||||
|
||||
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
|
||||
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
// divide block work by [M, N]
|
||||
#if 0
|
||||
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
|
||||
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
|
||||
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
|
||||
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
|
||||
|
||||
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
|
||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||
#else
|
||||
// HACK: this forces the result into SGPRs
|
||||
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
|
||||
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
|
||||
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||
|
||||
const index_t k_block_work_id =
|
||||
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
|
||||
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
|
||||
|
||||
const index_t ho_block_work_id =
|
||||
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
|
||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||
#endif
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align =
|
||||
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: find a more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_e_k_block_desc),
|
||||
decltype(b_e_n_ho_wo_block_desc),
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K>{};
|
||||
|
||||
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const auto k_thread_id = c_thread_mtx_index.k;
|
||||
const auto ho_thread_id = c_thread_mtx_index.h;
|
||||
const auto wo_thread_id = c_thread_mtx_index.w;
|
||||
|
||||
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
|
||||
|
||||
const index_t ho_thread_data_on_global =
|
||||
ho_block_data_on_global + ho_thread_id * HoPerThread;
|
||||
const index_t wo_thread_data_on_global =
|
||||
wo_block_data_on_global + wo_thread_id * WoPerThread;
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<E, KPerBlock>,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_e_k_global_desc),
|
||||
decltype(a_e_k_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_e_k_global_desc,
|
||||
make_multi_index(0, k_block_data_on_global),
|
||||
a_e_k_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
auto b_threadwise_transfer =
|
||||
ThreadwiseTensorSliceTransfer_v2<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_e_n_ho_wo_global_desc),
|
||||
decltype(b_e_n_ho_wo_thread_desc),
|
||||
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
1,
|
||||
true>(
|
||||
b_e_n_ho_wo_global_desc,
|
||||
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_shared_block, a_e_k_desc.GetElementSpaceSize());
|
||||
|
||||
// register allocation for output
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
FloatAcc,
|
||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
|
||||
// initialize output thread tensor
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
|
||||
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over the A and B matrices for threadwise copy
|
||||
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
|
||||
|
||||
// hack to control index calculation when moving the slice window for the
// A and B matrices for threadwise copy
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
BGlobalMoveSliceWindowStepHacks{};
|
||||
|
||||
// double register buffer for B
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
FloatAB,
|
||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
b_thread_even_buf, b_thread_odd_buf;
|
||||
|
||||
// LDS double buffer: preload data
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t e_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
e_block_data_begin += 2 * EPerBlock;
|
||||
|
||||
} while(e_block_data_begin < E - 2 * EPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + k_thread_id * KPerThread;
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
decltype(c_k_n_ho_wo_global_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>(
|
||||
c_k_n_ho_wo_global_desc,
|
||||
make_multi_index(
|
||||
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
|
||||
.Run(c_k_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
c_global_buf,
|
||||
c_k_n_ho_wo_global_tensor_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
// pass tensor descriptors by reference
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc& b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc& c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by their pointers
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
|
||||
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
|
||||
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by void*
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const void* p_a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const void* p_b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const void* p_c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
|
||||
const auto b_e_n_ho_wo_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
|
||||
const auto c_k_n_ho_wo_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
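// The void* overload is presumably for callers that stage the (compile-time
// typed) descriptors through an untyped buffer, e.g. device memory filled on
// the host; the reinterpret_casts above recover the static types.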
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif

include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp (new file, 1594 lines; diff suppressed because it is too large)

@@ -0,0 +1,325 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
|
||||
#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer,
|
||||
index_t NumPrefetch,
|
||||
bool HasMainLoop>
|
||||
struct GridwiseGemmPipeline_v1;
|
||||
|
||||
// 1-stage prefetch
|
||||
template <typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer,
|
||||
bool HasMainLoop>
|
||||
struct GridwiseGemmPipeline_v1<AGridDesc,
|
||||
ABlockDesc,
|
||||
ABlockTransfer,
|
||||
AGridBuffer,
|
||||
ABlockBuffer,
|
||||
ABlockTransferStep,
|
||||
BGridDesc,
|
||||
BBlockDesc,
|
||||
BBlockTransfer,
|
||||
BGridBuffer,
|
||||
BBlockBuffer,
|
||||
BBlockTransferStep,
|
||||
BlockwiseGemm,
|
||||
CThreadBuffer,
|
||||
1,
|
||||
HasMainLoop>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static __device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
#if 0
|
||||
// preload data into LDS
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#else
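// Enabled schedule: MoveSrcSliceWindow is issued immediately after the
// preload reads, so each iteration's RunRead fetches the *next* tile while
// blockwise_gemm works on the tile already resident in LDS.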
|
||||
// preload data into LDS
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// 2-stage prefetch
|
||||
template <typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer,
|
||||
bool HasMainLoop>
|
||||
struct GridwiseGemmPipeline_v1<AGridDesc,
|
||||
ABlockDesc,
|
||||
ABlockTransfer,
|
||||
AGridBuffer,
|
||||
ABlockBuffer,
|
||||
ABlockTransferStep,
|
||||
BGridDesc,
|
||||
BBlockDesc,
|
||||
BBlockTransfer,
|
||||
BGridBuffer,
|
||||
BBlockBuffer,
|
||||
BBlockTransferStep,
|
||||
BlockwiseGemm,
|
||||
CThreadBuffer,
|
||||
2,
|
||||
HasMainLoop>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static __device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// preload data into LDS
|
||||
{
|
||||
// Read 0
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
// Move
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Read 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
}
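// Steady state of the 2-stage pipeline: tiles i and i+1 already sit in the
// transfer's two register stages (slots I0/I1). Each half-iteration writes
// one slot to LDS, issues the global read for the tile two steps ahead into
// the slot just freed, then syncs and runs the GEMM on the LDS-resident tile.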
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
// Move
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Write i
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
|
||||
|
||||
// Read i+2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm i
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Move
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Write i+1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
|
||||
|
||||
// Read i+3
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm i+1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
// Write num_loop - 2
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm num_loop - 2
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Write num_loop - 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm num_loop - 1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
@@ -0,0 +1,600 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
|
||||
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "gridwise_gemm_pipeline_v1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainK0BlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r3(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M_N,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t K1Value,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool BBlockLdsExtraN,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
index_t NumPrefetch = 1>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_k0_m_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
return a_block_desc_k0_m_k1;
|
||||
}
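// With ABlockLdsExtraM the row stride becomes (MPerBlock + 1) * K1 instead of
// MPerBlock * K1, i.e. one K1-wide element of padding per K0 row; the usual
// motivation for such padding is to stagger accesses and reduce LDS bank
// conflicts, at the cost of slightly more LDS.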
|
||||
|
||||
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_k0_n_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
return b_block_desc_k0_n_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
|
||||
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
constexpr auto a_block_space_size_aligned =
|
||||
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size_aligned =
|
||||
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
|
||||
}
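// Illustrative sizing (not a tuned configuration): with K0PerBlock = 4,
// MPerBlock = NPerBlock = 128, K1 = 8, FloatAB = half and no Extra padding,
// each tile is 4 * 128 * 8 = 4096 elements, so LDS usage is
// (4096 + 4096) * 2 bytes = 16 KiB.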
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check NumPrefetch
|
||||
if constexpr(NumPrefetch == 1)
|
||||
{
|
||||
// 1-stage prefetch always supported
|
||||
}
|
||||
else if constexpr(NumPrefetch == 2)
|
||||
{
|
||||
// 2-stage prefetch currently only supports an even number of K0 loop iterations
// TODO: add support for an odd number of K0 loop iterations
|
||||
if(!((K0 / K0PerBlock) % 2 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check M01, N01
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
if(!(M0 % M01 == 0 && N0 % N01 == 0))
|
||||
return false;
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
// TODO move this function into GEMM-pipeline class
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
}
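// e.g. K0 = 64, K0PerBlock = 8, NumPrefetch = 1 gives 64 / 8 = 8 > 1, so the
// main loop runs; K0 = 8 gives exactly 1 and only the tail is executed.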
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_k0_m_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_k0_n_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
K1>;
|
||||
|
||||
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
|
||||
|
||||
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
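// Net effect: block_id is decomposed as (m00, n00, m01, n01) and mapped to the
// C tile (m0, n0) = (m00 * M01 + m01, n00 * N01 + n01), so consecutive block
// ids walk an M01 x N01 cluster of neighboring tiles, which tends to improve
// reuse of the A/B data those tiles share.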
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_grid_desc_k0_m_k1),
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_k0_m_k1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_grid_desc_k0_n_k1),
                                              decltype(b_block_desc_k0_n_k1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              BBlockTransferSrcVectorDim,
                                              2,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true,
                                              NumPrefetch>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
                                    remove_cvref_t<decltype(a_blockwise_copy)>,
                                    remove_cvref_t<decltype(a_grid_buf)>,
                                    remove_cvref_t<decltype(a_block_buf)>,
                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
                                    remove_cvref_t<decltype(b_blockwise_copy)>,
                                    remove_cvref_t<decltype(b_grid_buf)>,
                                    remove_cvref_t<decltype(b_block_buf)>,
                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(blockwise_gemm)>,
                                    remove_cvref_t<decltype(c_thread_buf)>,
                                    NumPrefetch,
                                    HasMainK0BlockLoop>{};

        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
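        // K0BlockMainLoop is the number of K0PerBlock tiles the pipeline will
        // consume; it is the same for every lane, hence readfirstlane again.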

        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
                                   a_block_desc_k0_m_k1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_k0_n_k1,
                                   b_block_desc_k0_n_k1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   K0BlockMainLoop);

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));
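            // CalculateBottomIndex inverts the merge M = M0*M1*M2*M3*M4, i.e. it
            // decomposes the linear m coordinate into the (M0, ..., M4) index of
            // the element this thread owns.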

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   CElementwiseOperation,
                                                   Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
                                                   CThreadTransferDstScalarPerVector,
                                                   CGlobalMemoryDataOperation,
                                                   1,
                                                   true>{
                    c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(m_thread_data_on_grid_idx[I0],
                                     n_thread_data_on_grid_idx[I0],
                                     m_thread_data_on_grid_idx[I1],
                                     n_thread_data_on_grid_idx[I1],
                                     m_thread_data_on_grid_idx[I2],
                                     m_thread_data_on_grid_idx[I3],
                                     m_thread_data_on_grid_idx[I4],
                                     n_thread_data_on_grid_idx[I2]),
                    c_element_op};

            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,634 @@
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CM0N0M1N1M2M3M4N2GridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename CBlockClusterAdaptor,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
                                const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
                                const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                const AElementwiseOperation a_element_op,
                                const BElementwiseOperation b_element_op,
                                const CElementwiseOperation c_element_op,
                                const CBlockClusterAdaptor c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];
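    // Single static LDS allocation backing both the A and B tiles; the byte count
    // from GetSharedMemoryNumberOfByte() is converted to a FloatAB element count.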

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared_block,
                                                  a_b_k0_m_k1_grid_desc,
                                                  b_b_k0_n_k1_grid_desc,
                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
                                                  c_block_cluster_adaptor);
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CMNGridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();
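        // When ABlockLdsExtraM is set, the explicit (MPerBlock + 1) * K1 row stride
        // pads each K0 slice by one K1-vector; the padding is intended to reduce
        // LDS bank conflicts when the block is read back for the xdlops GEMM.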

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size =
            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
    }

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                  const CMNGridDesc& c_m_n_grid_desc,
                  index_t M01,
                  index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
        const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
             K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
             K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
             K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
             KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
                                                                K1>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc);
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KBatch),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_kbatch_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
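
        // Chaining the two adaptors gives the complete block_id ->
        // (kbatch, m0, n0) mapping used to hand out C tiles to workgroups.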

        return cblockid_to_kbatch_m0_n0_block_cluster_adaptor;
    }

    using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                               const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                               const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        constexpr auto a_b_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                               Number<MPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    max_lds_align);
            }
        }();
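        // The 4-d (1, K0PerBlock, MPerBlock, K1) variant mirrors the rank of the
        // batched grid descriptor, so the same blockwise copy can address the
        // grid source and the LDS destination with one set of coordinates.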
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        constexpr auto b_b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                               Number<NPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    max_lds_align);
            }
        }();
        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, MPerBlock, K1>,
                                              ABlockTransferThreadClusterLengths_K0_M_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(a_b_k0_m_k1_grid_desc),
                                              decltype(a_b_k0_m_k1_block_desc),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              ABlockTransferSrcVectorDim,
                                              3,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                a_b_k0_m_k1_grid_desc,
                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_b_k0_m_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_b_k0_n_k1_grid_desc),
                                              decltype(b_b_k0_n_k1_block_desc),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              BBlockTransferSrcVectorDim,
                                              3,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                b_b_k0_n_k1_grid_desc,
                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_b_k0_n_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
        }
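
        // Two-stage software pipeline: the loop below issues the global reads for
        // tile i+1 before running the xdlops GEMM on tile i, and the final tile is
        // consumed by the tail section after the loop.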

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainKBlockLoop)
        {
            index_t k0_block_data_begin = 0;

            do
            {
                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

                k0_block_data_begin += K0PerBlock;
            } while(k0_block_data_begin < (K0 - K0PerBlock));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
            constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
            constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
            constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
            constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                make_naive_tensor_descriptor_packed(make_tuple(
                    Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                                                   decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
                                                   CElementwiseOperation,
                                                   Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
                                                   CThreadTransferDstScalarPerVector,
                                                   CGlobalMemoryDataOperation,
                                                   1,
                                                   true>{
                    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                    make_multi_index(m_thread_data_on_grid_idx[I0],
                                     n_thread_data_on_grid_idx[I0],
                                     m_thread_data_on_grid_idx[I1],
                                     n_thread_data_on_grid_idx[I1],
                                     m_thread_data_on_grid_idx[I2],
                                     m_thread_data_on_grid_idx[I3],
                                     m_thread_data_on_grid_idx[I4],
                                     n_thread_data_on_grid_idx[I2]),
                    c_element_op};

            c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                              c_grid_buf);
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,743 @@
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename CBlockClusterAdaptor,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid,
                                  const FloatAB* __restrict__ p_b_grid,
                                  FloatC* __restrict__ p_c_grid,
                                  const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
                                  const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
                                  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
                                  const AElementwiseOperation a_element_op,
                                  const BElementwiseOperation b_element_op,
                                  const CElementwiseOperation c_element_op,
                                  const CBlockClusterAdaptor c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared_block,
                                                  a_b_k0_m_k1_grid_desc,
                                                  b_b_k0_n_k1_grid_desc,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
                                                  c_block_cluster_adaptor);
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CMNGridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size =
            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_block_size =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();

        return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
                         c_block_size * sizeof(FloatC));
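        // Unlike v2r4, the same LDS is later reused to stage C shuffle tiles, so
        // the allocation is the max of the A/B footprint and the C tile footprint.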
    }

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                  const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                  const CMNGridDesc& c_m_n_grid_desc,
                  index_t M01,
                  index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
        const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
             K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
             K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
             K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
             KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        return transform_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KBatch),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       Number<CShuffleMRepeatPerShuffle * MWaves * MPerXDL>{},
                       I1,
                       Number<CShuffleNRepeatPerShuffle * NWaves * NPerXDL>{}));
    }

    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                               const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        constexpr auto a_b_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                               Number<MPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    max_lds_align);
            }
        }();
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        constexpr auto b_b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                               Number<NPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    max_lds_align);
            }
        }();
        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, MPerBlock, K1>,
                                              ABlockTransferThreadClusterLengths_K0_M_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(a_b_k0_m_k1_grid_desc),
                                              decltype(a_b_k0_m_k1_block_desc),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              ABlockTransferSrcVectorDim,
                                              3,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                a_b_k0_m_k1_grid_desc,
                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_b_k0_m_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_b_k0_n_k1_grid_desc),
                                              decltype(b_b_k0_n_k1_block_desc),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              BBlockTransferSrcVectorDim,
                                              3,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                b_b_k0_n_k1_grid_desc,
                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_b_k0_n_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
        }

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainKBlockLoop)
        {
            index_t k0_block_data_begin = 0;

            do
            {
                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

                k0_block_data_begin += K0PerBlock;
            } while(k0_block_data_begin < (K0 - K0PerBlock));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
            constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
            constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
            constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
            constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
            constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

            constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
                GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
                static_cast<FloatC*>(p_shared_block),
                c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            static_assert(M1 == MWaves, "");
            static_assert(N1 == NWaves, "");
            static_assert(M2 * M3 * M4 == MPerXDL, "");
            static_assert(N2 == NPerXDL, "");

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0), // freeze mblock
                    make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
                                                      M1,
                                                      M2,
                                                      M3,
                                                      M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
                    make_freeze_transform(I0), // freeze nblock
                    make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
                                                      N1,
                                                      N2))), // N1 = NWave, N2 = NPerXDL
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMRepeatPerShuffle,
                                                            CShuffleNRepeatPerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};
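
            // c_thread_copy_vgpr_to_lds stages one shuffle tile of accumulators
            // from registers into LDS, so the block-wide copy below can issue
            // vectorized global stores along the N dimension.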
|
||||
|
||||
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
|
||||
BlockSize, // index_t BlockSize,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWaves * MPerXDL,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWaves * NPerXDL>, // BlockSliceLengths,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
FloatC, // typename SrcData,
|
||||
FloatC, // typename DstData,
|
||||
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun
|
||||
{c_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
|
||||
c_element_op};
|
||||
|
||||
constexpr auto mxdlperwave_forward_step =
|
||||
make_multi_index(0, CShuffleMRepeatPerShuffle * MWaves * MPerXDL, 0, 0);
|
||||
constexpr auto nxdlperwave_forward_step =
|
||||
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWaves * NPerXDL);
|
||||
constexpr auto nxdlperwave_backward_step =
|
||||
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWaves * NPerXDL);
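
            // Editorial note on the traversal below: the N direction is swept in a
            // zigzag (forward on even M steps, backward on odd ones), so each step
            // only has to move the destination slice window by one shuffle tile
            // instead of rewinding across the whole N range after every M step.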
            static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                constexpr auto mxdlperwave = mxdlperwave_iter;

                static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                    constexpr bool nxdlperwave_forward_sweep =
                        (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                    constexpr index_t nxdlperwave_value =
                        nxdlperwave_forward_sweep
                            ? nxdlperwave_iter
                            : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                    constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                    // make sure it's safe to do ds_write
                    block_sync_lds();

                    // VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(
                        c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                        make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        c_block_buf);

                    // make sure it's safe to do ds_read
                    block_sync_lds();

                    // LDS to global
                    c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
                                                   c_block_buf,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   c_grid_buf);

                    // move on nxdlperwave dimension
                    if constexpr(nxdlperwave_forward_sweep &&
                                 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            nxdlperwave_forward_step);
                    }
                    else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            nxdlperwave_backward_step);
                    }
                });

                // move on mxdlperwave dimension
                if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
                }
            });
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,762 @@
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
          bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v3r1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid,
        p_b_grid,
        p_c_grid,
        p_shared,
        a_grid_desc_ak0_m_ak1,
        b_grid_desc_bk0_n_bk1,
        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        a_element_op,
        b_element_op,
        c_element_op,
        block_2_ctile_map);
}

template <
    index_t BlockSize,
    typename FloatAB,
    typename FloatAcc,
    typename FloatCShuffle,
    typename FloatC,
    InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
    typename AGridDesc_AK0_M_AK1,
    typename BGridDesc_BK0_N_BK1,
    typename CGridDesc_M_N,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t KPerBlock,
    index_t AK1Value,
    index_t BK1Value,
    index_t MPerXdl,
    index_t NPerXdl,
    index_t MXdlPerWave,
    index_t NXdlPerWave,
    typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    index_t ABlockTransferSrcVectorDim,
    index_t ABlockTransferSrcScalarPerVector,
    index_t ABlockTransferDstScalarPerVector_K1,
    bool AThreadTransferSrcResetCoordinateAfterRun,
    bool ABlockLdsExtraM,
    typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    index_t BBlockTransferSrcVectorDim,
    index_t BBlockTransferSrcScalarPerVector,
    index_t BBlockTransferDstScalarPerVector_K1,
    bool BThreadTransferSrcResetCoordinateAfterRun,
    bool BBlockLdsExtraN,
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
    index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
    index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};
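
    // Editorial note: AK0 * AK1 == BK0 * BK1 == KPerBlock, so A and B may use
    // different K1 (vector) widths while covering the same K extent per block;
    // this requires KPerBlock to be divisible by both AK1Value and BK1Value.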

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        constexpr auto max_lds_align = AK1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(AK0, Number<MPerBlock>{}, AK1),
                    make_tuple(Number<MPerBlock + 1>{} * AK1, AK1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(AK0, Number<MPerBlock>{}, AK1), max_lds_align);
            }
        }();

        return a_block_desc_ak0_m_ak1;
    }
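
    // Editorial note: with ABlockLdsExtraM the AK0 stride is (MPerBlock + 1) * AK1
    // rather than MPerBlock * AK1, i.e. one K1-vector of padding per AK0 slice; a
    // common trick to stagger LDS addresses and reduce bank conflicts, at the cost
    // of a slightly larger allocation.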

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        constexpr auto max_lds_align = BK1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(BK0, Number<NPerBlock>{}, BK1),
                    make_tuple(Number<NPerBlock + 1>{} * BK1, BK1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(BK0, Number<NPerBlock>{}, BK1), max_lds_align);
            }
        }();

        return b_block_desc_bk0_n_bk1;
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto
            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<CShuffleMXdlPerWavePerShuffle>{},
                               Number<MWave * MPerXdl>{},
                               I1,
                               Number<CShuffleNXdlPerWavePerShuffle>{},
                               Number<NWave * NPerXdl>{}));

        return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1);

        constexpr auto b_block_space_size_aligned =
            math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
            GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

        constexpr auto c_block_size =
            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatCShuffle));
    }
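
    // Editorial note: the A/B input tiles and the C shuffle tile are live in
    // disjoint phases of Run() and share the single p_shared allocation, so the
    // requirement is the max of the two footprints rather than their sum.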

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  index_t M01,
                  index_t N01)
    {
        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
        //               "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
        const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
            return false;

        // check NumPrefetch
        if constexpr(NumPrefetch == 1)
        {
            // 1-stage prefetch always supported
        }
        else if constexpr(NumPrefetch == 2)
        {
            // 2-stage prefetch currently only supports an even number of K0 loops
            // TODO: add support for an odd number of K0 loops
            if(!((K / KPerBlock) % 2 == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    // TODO: move this function into the GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1;

        return has_main_k0_block_loop;
    }
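
    // Editorial example: K0 * AK1 is the total K extent, so with K = 512,
    // KPerBlock = 64 and NumPrefetch = 1 this gives 512 / 64 = 8 > 1 and the
    // pipeline runs its main K loop; when at most NumPrefetch KPerBlock tiles
    // remain, HasMainK0BlockLoop is false and the work is assumed to be covered
    // entirely by the pipeline's start-up/drain stages.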

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_unmerge_transform(make_tuple(
                               MBlock, Number<MXdlPerWave>{}, Number<MWave * MPerXdl>{})),
                           make_unmerge_transform(make_tuple(
                               NBlock, Number<NXdlPerWave>{}, Number<NWave * NPerXdl>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }
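
    // Editorial note: the chained adaptors first split the 1-D block id into
    // (m00, n00, m01, n01) and then unmerge to the tile index (m0, n0), so
    // M01 * N01 consecutive block ids land on a cluster of adjacent C tiles;
    // M01/N01 are intended as a knob for the locality of the tile-to-block
    // assignment.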

    using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
        remove_cvref_t<decltype(
            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                CGridDesc_M_N{}))>;

    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        void* __restrict__ p_shared,
        const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
        const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CElementwiseOperation& c_element_op,
        const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid,
            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize());

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<AK0, MPerBlock, AK1>,
                                              ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(a_grid_desc_ak0_m_ak1),
                                              decltype(a_block_desc_ak0_m_ak1),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              ABlockTransferSrcVectorDim,
                                              2,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true,
                                              NumPrefetch>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<BK0, NPerBlock, BK1>,
                                              BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_grid_desc_bk0_n_bk1),
                                              decltype(b_block_desc_bk0_n_bk1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              BBlockTransferSrcVectorDim,
                                              2,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true,
                                              NumPrefetch>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check
        constexpr index_t k_pack = math::max(
            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
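
        // Editorial note: k_pack sets the K granularity fed to the blockwise xdlops
        // GEMM; the lcm keeps it compatible with both the AK1 and BK1 LDS layouts,
        // and selected_mfma.k_per_blk is taken to be the minimum K each mfma
        // instruction consumes per issue.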

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_ak0_m_ak1),
                                                                decltype(b_block_desc_bk0_n_bk1),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                k_pack>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
                                    remove_cvref_t<decltype(a_blockwise_copy)>,
                                    remove_cvref_t<decltype(a_grid_buf)>,
                                    remove_cvref_t<decltype(a_block_buf)>,
                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
                                    remove_cvref_t<decltype(b_blockwise_copy)>,
                                    remove_cvref_t<decltype(b_grid_buf)>,
                                    remove_cvref_t<decltype(b_block_buf)>,
                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(blockwise_gemm)>,
                                    remove_cvref_t<decltype(c_thread_buf)>,
                                    NumPrefetch,
                                    HasMainK0BlockLoop>{};

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
                                   a_block_desc_ak0_m_ak1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_bk0_n_bk1,
                                   b_block_desc_bk0_n_bk1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   num_k_block_main_loop);
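
        // Editorial note on why C is shuffled through LDS before the global write:
        // each thread's C fragments follow the scattered M2/M3/M4 xdlops output
        // pattern, so direct global stores would be poorly vectorized; staging one
        // CShuffle{M,N}XdlPerWavePerShuffle tile at a time in LDS lets the block
        // re-tile the data and store it with
        // CBlockTransferScalarPerVector_NWaveNPerXdl-wide vectors.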

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
                GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                    .GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                make_tuple(
                    make_freeze_transform(I0), // freeze mblock
                    make_pass_through_transform(
                        Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
                    make_unmerge_transform(
                        make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                    make_freeze_transform(I0), // freeze nblock
                    make_pass_through_transform(
                        Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
                    make_unmerge_transform(
                        make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<>{},
                           Sequence<0>{},
                           Sequence<2, 4, 5, 6>{},
                           Sequence<>{},
                           Sequence<1>{},
                           Sequence<3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
                BlockSize,                  // index_t BlockSize,
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle,
                         MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle,
                         NWave * NPerXdl>, // BlockSliceLengths,
                CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
                Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,              // typename SrcData,
                FloatC,                     // typename DstData,
                decltype(
                    c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                decltype(
                    c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
                5,                                          // index_t VectorDim,
                CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun
                {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(0, 0, 0, 0, 0, 0),
                 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 c_element_op};

            constexpr auto mxdlperwave_forward_step =
                make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
            constexpr auto nxdlperwave_forward_step =
                make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
            constexpr auto nxdlperwave_backward_step =
                make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);

            static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
                constexpr auto mxdlperwave = mxdlperwave_iter;

                static_for<0,
                           NXdlPerWave,
                           CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
                    constexpr bool nxdlperwave_forward_sweep =
                        (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);

                    constexpr index_t nxdlperwave_value =
                        nxdlperwave_forward_sweep
                            ? nxdlperwave_iter
                            : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);

                    constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                    // make sure it's safe to do ds_write
                    block_sync_lds();

                    // VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(
                        c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        c_shuffle_block_buf);

                    // make sure it's safe to do ds_read
                    block_sync_lds();

                    // LDS to global
                    c_block_copy_lds_to_global.Run(
                        c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c_shuffle_block_buf,
                        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c_grid_buf);

                    // move on nxdlperwave dimension
                    if constexpr(nxdlperwave_forward_sweep &&
                                 (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_forward_step);
                    }
                    else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_backward_step);
                    }
                });

                // move on mxdlperwave dimension
                if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        mxdlperwave_forward_step);
                }
            });
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,800 @@
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V3R2_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
          bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v3r2(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const FloatC* __restrict__ p_c0_grid,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
            const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
                c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid,
        p_b_grid,
        p_c_grid,
        p_c0_grid,
        p_shared,
        a_grid_desc_k0_m_k1,
        b_grid_desc_k0_n_k1,
        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        a_element_op,
        b_element_op,
        c_element_op,
        block_2_ctile_map);
}

template <
    index_t BlockSize,
    typename FloatAB,
    typename FloatAcc,
    typename FloatC,
    InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
    typename AGridDesc_K0_M_K1,
    typename BGridDesc_K0_N_K1,
    typename CGridDesc_M_N,
    typename C0GridDesc_M_N,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t K0PerBlock,
    index_t MPerXdl,
    index_t NPerXdl,
    index_t K1Value,
    index_t MXdlPerWave,
    index_t NXdlPerWave,
    typename ABlockTransferThreadClusterLengths_K0_M_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    index_t ABlockTransferSrcVectorDim,
    index_t ABlockTransferSrcScalarPerVector,
    index_t ABlockTransferDstScalarPerVector_K1,
    bool AThreadTransferSrcResetCoordinateAfterRun,
    bool ABlockLdsExtraM,
    typename BBlockTransferThreadClusterLengths_K0_N_K1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    index_t BBlockTransferSrcVectorDim,
    index_t BBlockTransferSrcScalarPerVector,
    index_t BBlockTransferDstScalarPerVector_K1,
    bool BThreadTransferSrcResetCoordinateAfterRun,
    bool BBlockLdsExtraN,
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
    index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
    index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0_n_k1;
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto
            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<CShuffleMXdlPerWavePerShuffle>{},
                               Number<MWave * MPerXdl>{},
                               I1,
                               Number<CShuffleNXdlPerWavePerShuffle>{},
                               Number<NWave * NPerXdl>{}));

        return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned =
            math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
            GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

        constexpr auto c_block_size =
            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatC));
    }

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  index_t M01,
                  index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check NumPrefetch
        if constexpr(NumPrefetch == 1)
        {
            // 1-stage prefetch always supported
        }
        else if constexpr(NumPrefetch == 2)
        {
            // 2-stage prefetch currently only supports an even number of K0 loops
            // TODO: add support for an odd number of K0 loops
            if(!((K0 / K0PerBlock) % 2 == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    // TODO: move this function into the GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }

    template <typename CGridDesc_M_N_>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
        const CGridDesc_M_N_& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_unmerge_transform(make_tuple(
                               MBlock, Number<MXdlPerWave>{}, Number<MWave * MPerXdl>{})),
                           make_unmerge_transform(make_tuple(
                               NBlock, Number<NXdlPerWave>{}, Number<NWave * NPerXdl>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
        remove_cvref_t<decltype(
            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                CGridDesc_M_N{}))>;

    using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
        remove_cvref_t<decltype(
            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                C0GridDesc_M_N{}))>;

    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const FloatC* __restrict__ p_c0_grid,
        void* __restrict__ p_shared,
        const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
        const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CElementwiseOperation& c_element_op,
        const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid,
            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize());
        auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c0_grid,
            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize());
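
        // Editorial note: relative to v3r1, this kernel reads a second C-shaped
        // input c0 (presumably a bias- or residual-style term); it is consumed by
        // the BlockwiseTensorSliceTransfer_v6r2 copy later in this function, so
        // CElementwiseOperation can combine it with the GEMM result during the
        // final store.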
|
||||
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_grid_desc_k0_m_k1),
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_k0_m_k1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_grid_desc_k0_n_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_k0_n_k1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[K0PerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[K0PerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
K1>{};
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned =
|
||||
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_k0_n_k1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
|
||||
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
|
||||
remove_cvref_t<decltype(a_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(a_grid_buf)>,
|
||||
remove_cvref_t<decltype(a_block_buf)>,
|
||||
remove_cvref_t<decltype(a_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
|
||||
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
|
||||
remove_cvref_t<decltype(b_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(b_grid_buf)>,
|
||||
remove_cvref_t<decltype(b_block_buf)>,
|
||||
remove_cvref_t<decltype(b_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(blockwise_gemm)>,
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumPrefetch,
|
||||
HasMainK0BlockLoop>{};
|
||||
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
                                   a_block_desc_k0_m_k1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_k0_n_k1,
                                   b_block_desc_k0_n_k1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   K0BlockMainLoop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong! M/NXdlPerWave must be divisible by the per-shuffle sizes");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
                GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
                static_cast<FloatC*>(p_shared),
                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                    .GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                make_tuple(
                    make_freeze_transform(I0), // freeze mblock
                    make_pass_through_transform(
                        Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
                    make_unmerge_transform(
                        make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                    make_freeze_transform(I0), // freeze nblock
                    make_pass_through_transform(
                        Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
                    make_unmerge_transform(
                        make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<>{},
                           Sequence<0>{},
                           Sequence<2, 4, 5, 6>{},
                           Sequence<>{},
                           Sequence<1>{},
                           Sequence<3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
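            // The flat M/N offsets above are decomposed back into (M0, M1, M2, M3, M4)
            // and (N0, N1, N2) coordinates by the merge adaptors below: a flat index is
            // read off as mixed-radix digits in those lengths, least-significant last,
            // e.g. lengths (2, 4) map flat index 5 to (1, 1).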

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r2<
                BlockSize,                  // index_t BlockSize,
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle,
                         MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle,
                         NWave * NPerXdl>, // BlockSliceLengths,
                CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
                Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
                FloatC,                     // typename Src0Data,
                FloatC,                     // typename Src1Data,
                FloatC,                     // typename DstData,
                decltype(
                    c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                decltype(
                    c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                decltype(
                    c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
                5,                                          // index_t VectorDim,
                CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrc0ResetCoordinateAfterRun,
                false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun
                {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(0, 0, 0, 0, 0, 0),
                 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 c_element_op};

            constexpr auto mxdlperwave_forward_step =
                make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
            constexpr auto nxdlperwave_forward_step =
                make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
            constexpr auto nxdlperwave_backward_step =
                make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);

            static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
                constexpr auto mxdlperwave = mxdlperwave_iter;

                static_for<0,
                           NXdlPerWave,
                           CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
                    constexpr bool nxdlperwave_forward_sweep =
                        (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);

                    constexpr index_t nxdlperwave_value =
                        nxdlperwave_forward_sweep
                            ? nxdlperwave_iter
                            : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);

                    constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                    // make sure it's safe to do ds_write
                    block_sync_lds();

                    // VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(
                        c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        c_block_buf);

                    // make sure it's safe to do ds_read
                    block_sync_lds();
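                    // Note: two barriers guard each shuffle tile. The first (before the
                    // VGPR-to-LDS copy) makes it safe to overwrite LDS that other waves
                    // may still be reading; this one makes the freshly written tile
                    // visible to every wave before the LDS-to-global copy reads it.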

                    // LDS to global
                    c_block_copy_lds_to_global.Run(
                        c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c_block_buf,
                        c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c0_grid_buf,
                        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c_grid_buf);

                    // move on nxdlperwave dimension
                    if constexpr(nxdlperwave_forward_sweep &&
                                 (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
                    {
                        c_block_copy_lds_to_global.MoveSrc1SliceWindow(
                            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_forward_step);

                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_forward_step);
                    }
                    else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                    {
                        c_block_copy_lds_to_global.MoveSrc1SliceWindow(
                            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_backward_step);

                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_backward_step);
                    }
                });

                // move on mxdlperwave dimension
                if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
                {
                    c_block_copy_lds_to_global.MoveSrc1SliceWindow(
                        c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        mxdlperwave_forward_step);

                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        mxdlperwave_forward_step);
                }
            });
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,837 @@
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r3.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          typename C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
          bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v3r3(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const FloatC* __restrict__ p_c0_grid,
            const FloatC* __restrict__ p_c1_grid,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
            const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
                c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
            const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
                c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
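    // The single LDS allocation is sized by GetSharedMemoryNumberOfByte() and reused
    // twice: first as the A/B tile staging area for the GEMM main loop, then as the
    // C shuffle buffer during the write-out phase.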

    GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid,
        p_b_grid,
        p_c_grid,
        p_c0_grid,
        p_c1_grid,
        p_shared,
        a_grid_desc_k0_m_k1,
        b_grid_desc_k0_n_k1,
        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        a_element_op,
        b_element_op,
        c_element_op,
        block_2_ctile_map);
}

template <
    index_t BlockSize,
    typename FloatAB,
    typename FloatAcc,
    typename FloatC,
    InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
    typename AGridDesc_K0_M_K1,
    typename BGridDesc_K0_N_K1,
    typename CGridDesc_M_N,
    typename C0GridDesc_M_N,
    typename C1GridDesc_M_N,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t K0PerBlock,
    index_t MPerXdl,
    index_t NPerXdl,
    index_t K1Value,
    index_t MXdlPerWave,
    index_t NXdlPerWave,
    typename ABlockTransferThreadClusterLengths_K0_M_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    index_t ABlockTransferSrcVectorDim,
    index_t ABlockTransferSrcScalarPerVector,
    index_t ABlockTransferDstScalarPerVector_K1,
    bool AThreadTransferSrcResetCoordinateAfterRun,
    bool ABlockLdsExtraM,
    typename BBlockTransferThreadClusterLengths_K0_N_K1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    index_t BBlockTransferSrcVectorDim,
    index_t BBlockTransferSrcScalarPerVector,
    index_t BBlockTransferDstScalarPerVector_K1,
    bool BThreadTransferSrcResetCoordinateAfterRun,
    bool BBlockLdsExtraN,
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
    index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
    index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0_n_k1;
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto
            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<CShuffleMXdlPerWavePerShuffle>{},
                               Number<MWave * MPerXdl>{},
                               I1,
                               Number<CShuffleNXdlPerWavePerShuffle>{},
                               Number<NWave * NPerXdl>{}));

        return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned =
            math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
            GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

        constexpr auto c_block_size =
            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatC));
    }
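    // A worked example with hypothetical tuning parameters (not taken from this file):
    // K0PerBlock = 4, MPerBlock = NPerBlock = 128, K1 = 8 and no LDS padding give
    // 4 * 128 * 8 = 4096 elements each for A and B, i.e. 8192 * sizeof(FloatAB) bytes
    // for the main loop, which math::max trades off against the C shuffle tile.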

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  index_t M01,
                  index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check NumPrefetch
        if constexpr(NumPrefetch == 1)
        {
            // 1-stage prefetch is always supported
        }
        else if constexpr(NumPrefetch == 2)
        {
            // 2-stage prefetch currently only supports an even number of K0 loops
            // TODO: add support for an odd number of K0 loops
            if(!((K0 / K0PerBlock) % 2 == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }
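    // A minimal host-side sketch (descriptor and variable names here are illustrative,
    // not defined in this file):
    //
    //   if(!GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc_m_n, /*M01=*/1, /*N01=*/1))
    //   {
    //       throw std::runtime_error("wrong! unsupported GEMM problem/configuration");
    //   }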

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    // TODO: move this function into the GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }
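    // Example: with NumPrefetch = 1 and K0PerBlock = 4, K0 = 32 gives 32 / 4 = 8 > 1,
    // so the pipeline runs its main loop; K0 = 4 gives 1 and takes the tail-only path.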

    template <typename CGridDesc_M_N_>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
        const CGridDesc_M_N_& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_unmerge_transform(make_tuple(
                               MBlock, Number<MXdlPerWave>{}, Number<MWave * MPerXdl>{})),
                           make_unmerge_transform(make_tuple(
                               NBlock, Number<NXdlPerWave>{}, Number<NWave * NPerXdl>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
    }
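    // Example of the unmerge above (illustrative sizes): M = 512, MPerBlock = 256,
    // MXdlPerWave = 2, MPerXdl = 32 imply MWave = 4, so the M axis splits into
    // (MBlock, MXdlPerWave, MWave * MPerXdl) = (2, 2, 128); N splits the same way.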

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }
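    // The chained adaptor decomposes a flat block id into (m00, n00, m01, n01) and
    // recombines it as m0 = m00 * M01 + m01, n0 = n00 * N01 + n01, so consecutive
    // block ids walk an M01 x N01 cluster of C tiles before moving to the next
    // cluster. Example (illustrative): M0 = N0 = 4 with M01 = N01 = 2 groups the
    // 16 tiles into four 2 x 2 clusters.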

    using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
        remove_cvref_t<decltype(
            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                CGridDesc_M_N{}))>;

    using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
        remove_cvref_t<decltype(
            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                C0GridDesc_M_N{}))>;

    using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
        remove_cvref_t<decltype(
            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                C1GridDesc_M_N{}))>;

    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const FloatC* __restrict__ p_c0_grid,
        const FloatC* __restrict__ p_c1_grid,
        void* __restrict__ p_shared,
        const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
        const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CElementwiseOperation& c_element_op,
        const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid,
            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize());
        auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c0_grid,
            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize());
        auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c1_grid,
            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                .GetElementSpaceSize());

        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<K0PerBlock, MPerBlock, K1>,
                                              ABlockTransferThreadClusterLengths_K0_M_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(a_grid_desc_k0_m_k1),
                                              decltype(a_block_desc_k0_m_k1),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              ABlockTransferSrcVectorDim,
                                              2,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                a_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_k0_m_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_grid_desc_k0_n_k1),
                                              decltype(b_block_desc_k0_n_k1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              BBlockTransferSrcVectorDim,
                                              2,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
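        // In index form: for each (m, n) owned by a thread,
        //   c[m][n] += sum over k0 < K0PerBlock and k1 < K1 of a[k0][m][k1] * b[k0][n][k1],
        // i.e. the K axis is pre-split into (K0, K1), with K1 the contiguous vector dim.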

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
                                    remove_cvref_t<decltype(a_blockwise_copy)>,
                                    remove_cvref_t<decltype(a_grid_buf)>,
                                    remove_cvref_t<decltype(a_block_buf)>,
                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
                                    remove_cvref_t<decltype(b_blockwise_copy)>,
                                    remove_cvref_t<decltype(b_grid_buf)>,
                                    remove_cvref_t<decltype(b_block_buf)>,
                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(blockwise_gemm)>,
                                    remove_cvref_t<decltype(c_thread_buf)>,
                                    NumPrefetch,
                                    HasMainK0BlockLoop>{};

        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
                                   a_block_desc_k0_m_k1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_k0_n_k1,
                                   b_block_desc_k0_n_k1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   K0BlockMainLoop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong! M/NXdlPerWave must be divisible by the per-shuffle sizes");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
                GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
                static_cast<FloatC*>(p_shared),
                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
                    .GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                make_tuple(
                    make_freeze_transform(I0), // freeze mblock
                    make_pass_through_transform(
                        Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
                    make_unmerge_transform(
                        make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                    make_freeze_transform(I0), // freeze nblock
                    make_pass_through_transform(
                        Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
                    make_unmerge_transform(
                        make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<>{},
                           Sequence<0>{},
                           Sequence<2, 4, 5, 6>{},
                           Sequence<>{},
                           Sequence<1>{},
                           Sequence<3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3<
                BlockSize,                  // index_t BlockSize,
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle,
                         MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle,
                         NWave * NPerXdl>, // BlockSliceLengths,
                CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
                Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
                FloatC,                     // typename Src0Data,
                FloatC,                     // typename Src1Data,
                FloatC,                     // typename Src2Data,
                FloatC,                     // typename DstData,
                decltype(
                    c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                decltype(
                    c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                decltype(
                    c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                decltype(
                    c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
                5,                                          // index_t VectorDim,
                CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrc0ResetCoordinateAfterRun,
                false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
                false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun
                {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(0, 0, 0, 0, 0, 0),
                 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 c_element_op};

            constexpr auto mxdlperwave_forward_step =
                make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
            constexpr auto nxdlperwave_forward_step =
                make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
            constexpr auto nxdlperwave_backward_step =
                make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
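            // The shuffle tiles are visited in a serpentine order: even M steps sweep N
            // forward, odd M steps sweep N backward, so only one slice-window move is
            // needed between consecutive tiles. Example with MXdlPerWave = NXdlPerWave = 4
            // and both per-shuffle sizes = 2: (m, n) visits (0,0) (0,2) (2,2) (2,0).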

            static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
                constexpr auto mxdlperwave = mxdlperwave_iter;

                static_for<0,
                           NXdlPerWave,
                           CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
                    constexpr bool nxdlperwave_forward_sweep =
                        (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);

                    constexpr index_t nxdlperwave_value =
                        nxdlperwave_forward_sweep
                            ? nxdlperwave_iter
                            : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);

                    constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                    // make sure it's safe to do ds_write
                    block_sync_lds();

                    // VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(
                        c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        c_block_buf);

                    // make sure it's safe to do ds_read
                    block_sync_lds();

                    // LDS to global
                    c_block_copy_lds_to_global.Run(
                        c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c_block_buf,
                        c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c0_grid_buf,
                        c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c1_grid_buf,
                        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        c_grid_buf);

                    // move on nxdlperwave dimension
                    if constexpr(nxdlperwave_forward_sweep &&
                                 (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
                    {
                        c_block_copy_lds_to_global.MoveSrc1SliceWindow(
                            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_forward_step);

                        c_block_copy_lds_to_global.MoveSrc2SliceWindow(
                            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_forward_step);

                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_forward_step);
                    }
                    else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                    {
                        c_block_copy_lds_to_global.MoveSrc1SliceWindow(
                            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_backward_step);

                        c_block_copy_lds_to_global.MoveSrc2SliceWindow(
                            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_backward_step);

                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                            nxdlperwave_backward_step);
                    }
                });

                // move on mxdlperwave dimension
                if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
                {
                    c_block_copy_lds_to_global.MoveSrc1SliceWindow(
                        c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        mxdlperwave_forward_step);

                    c_block_copy_lds_to_global.MoveSrc2SliceWindow(
                        c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        mxdlperwave_forward_step);

                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                        mxdlperwave_forward_step);
                }
            });
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,79 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_SET_BUFFER_VALUE_HPP
#define CK_GRIDWISE_SET_BUFFER_VALUE_HPP

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <index_t BlockSize, typename DataType, typename Grid1dBufferDescType>
__global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffer_desc,
                                        DataType* const __restrict__ p_global,
                                        DataType value)
{
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;

    constexpr auto I0 = Number<0>{};

    const index_t thread_local_id = get_thread_local_1d_id();
    const index_t block_global_id = get_block_1d_id();

    const index_t thread_global_id = block_global_id * BlockSize + thread_local_id;

    StaticBuffer<AddressSpaceEnum_t::Vgpr, DataType, 1, true> value_buf;

    value_buf(I0) = value;

    constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

    auto global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
        p_global, grid_1d_buffer_desc.GetElementSpaceSize());

    if(thread_global_id < grid_1d_buffer_desc.GetElementSize())
    {
        auto threadwise_store = ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                                                   DataType,
                                                                   decltype(val_buff_desc),
                                                                   Grid1dBufferDescType,
                                                                   PassThroughOp,
                                                                   Sequence<1>,
                                                                   Sequence<0>,
                                                                   0,
                                                                   1,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   1,
                                                                   true>(
            grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{});

        threadwise_store.Run(
            val_buff_desc, make_tuple(I0), value_buf, grid_1d_buffer_desc, global_buf);
    }
}
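
// A minimal host-side launch sketch (hypothetical names and sizes, not part of this
// file; the grid descriptor construction mirrors what callers typically pass in):
//
//   constexpr index_t kBlockSize = 256;
//   const auto grid_desc = make_naive_tensor_descriptor_packed(make_tuple(n));
//   const index_t grid_size = (n + kBlockSize - 1) / kBlockSize;
//   kernel_buffer_set_value<kBlockSize, float, decltype(grid_desc)>
//       <<<grid_size, kBlockSize, 0, stream>>>(grid_desc, p_buffer, 0.f);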

} // namespace ck
#endif