mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
universal streamk fp8 changes (#1665)
* universal streamk fp8 changes & ckprofiler instances * revert strides to -1 and verification options * fp8 exclusion on pre-gfx94 for universal_streamk * PR review based revisions: permissions reverted, removed hip err checks --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
fb1ccfa9df
commit
d6d4c2788b
816
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Normal file → Executable file
816
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Normal file → Executable file
@@ -14,6 +14,8 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/workgroup_barrier.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -38,7 +40,7 @@ __global__ void
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
|
||||
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -62,7 +64,13 @@ __global__ void
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
|
||||
karg.p_a_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
karg.p_workspace_);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_}
|
||||
p_c_grid{p_c_grid_},
|
||||
block_2_ctile_map_streamk(
|
||||
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
|
||||
|
||||
{
|
||||
}
|
||||
@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
StreamKReductionStrategy::Atomic,
|
||||
8,
|
||||
4>
|
||||
block_2_ctile_map_streamk;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// Builds the compile-time, packed 4-D tensor descriptor
// [MShuffle, MPerShuffle, NShuffle, NPerShuffle] with lengths
//   MShuffle    = MXdlPerWave / CShuffleMXdlPerWavePerShuffle
//   MPerShuffle = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl
//   NShuffle    = NXdlPerWave / CShuffleNXdlPerWavePerShuffle
//   NPerShuffle = CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl
// Takes no arguments; every length is derived from the enclosing struct's
// compile-time parameters.
// NOTE(review): when all divisions below are exact, the four lengths multiply to
// MPerBlock * NPerBlock. That exactness is not asserted here - confirm it is
// guaranteed by the instance configuration.
__host__ __device__ static constexpr auto
|
||||
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
|
||||
{
|
||||
// MWave / NWave: block tile extent divided by (XdlPerWave * PerXdl) along M / N
// (integer division).
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
// "packed": no explicit strides are passed, only the four lengths.
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
|
||||
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
|
||||
Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
|
||||
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
|
||||
}
|
||||
|
||||
using BlockwiseGemmPipe =
|
||||
remove_cvref_t<decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// Returns the compile-time 2-D cluster shape
// Sequence<MPerBlockReduction, NPerBlockReduction>, where
//   NPerBlockReduction = next_power_of_two(NPerBlock)
//                        / CShuffleBlockTransferScalarPerVector_NPerBlock
//   MPerBlockReduction = ceil(BlockSize / NPerBlockReduction)
// so the two lengths always multiply to at least BlockSize.
// NOTE(review): presumably fed to make_cluster_descriptor to split a flat id into
// (m, n) cluster coordinates - verify against the callers.
__host__ __device__ static constexpr auto GetClusterLengthReduction()
|
||||
{
|
||||
// TODO: assume C is row major
|
||||
// TODO: we always first loop over N, then M
|
||||
// Round NPerBlock up to a power of two before dividing by the vector width.
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
|
||||
constexpr auto NPerBlockReduction =
|
||||
NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
// Ceiling division: (a + b - 1) / b.
constexpr auto MPerBlockReduction =
|
||||
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
|
||||
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
|
||||
}
|
||||
|
||||
// Returns a naive (lengths + strides) descriptor for one MPerBlock x NPerBlock
// tile, with strides chosen from CLayout:
//   RowMajor    -> strides (NPerBlock, 1)
//   ColumnMajor -> strides (1, MPerBlock)
// NOTE(review): there is no fallback branch for any other CLayout. In that case the
// lambda returns nothing, which looks like it would fail to compile at the
// `const auto` initialization - confirm that this is the intended guard.
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
|
||||
{
|
||||
// Immediately-invoked lambda: `auto` takes the return type of whichever
// `if constexpr` branch is instantiated.
const auto c_partial_acc_block_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
// Row-major tile: N is the unit-stride dimension.
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
|
||||
make_tuple(NPerBlock, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
// Column-major tile: M is the unit-stride dimension.
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
|
||||
make_tuple(I1, MPerBlock));
|
||||
}
|
||||
}();
|
||||
return c_partial_acc_block_m_n;
|
||||
}
|
||||
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
Problem& problem)
|
||||
Problem& problem,
|
||||
void* p_workspace)
|
||||
{
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
if(is_reduction_block)
|
||||
{
|
||||
// descriptors
|
||||
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
|
||||
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
|
||||
const auto reduce_thread_cluster_idx =
|
||||
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
|
||||
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
|
||||
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
|
||||
|
||||
constexpr auto MReduceIters = math::integer_divide_ceil(
|
||||
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
|
||||
constexpr auto NReduceIters = math::integer_divide_ceil(
|
||||
Number<NPerBlock>{},
|
||||
cluster_length_reduce.At(I1) *
|
||||
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
|
||||
|
||||
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
constexpr auto acc_thread_buf_store_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
|
||||
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
|
||||
|
||||
constexpr auto partial_acc_load_step_n =
|
||||
make_multi_index(0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_m =
|
||||
make_multi_index(cluster_length_reduce.At(I0), 0);
|
||||
|
||||
constexpr auto partial_acc_store_step_n =
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_m =
|
||||
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
parcial_acc_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
acc_buf;
|
||||
|
||||
// start to compute
|
||||
auto reduction_idx =
|
||||
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
|
||||
reduction_idx, problem.M, problem.N);
|
||||
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
|
||||
uint32_t tile_acc_offset_start =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
|
||||
uint32_t tile_acc_offset_end =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
|
||||
1);
|
||||
__syncthreads();
|
||||
|
||||
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
AccDataType, // SrcData,
|
||||
AccDataType, // DstData,
|
||||
decltype(c_partial_acc_block_m_n), // SrcDesc,
|
||||
decltype(acc_thread_buf_load_desc), // DstDesc,
|
||||
Sequence<1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1>, // DimAccessOrder,
|
||||
1, // SrcVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
|
||||
1, // SrcScalarStrideInVector,
|
||||
false // SrcResetCoordinateAfterRun,
|
||||
>{c_partial_acc_block_m_n,
|
||||
make_multi_index(thread_m_cluster_id,
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock)};
|
||||
|
||||
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType, // SrcData,
|
||||
CDataType, // DstData,
|
||||
decltype(acc_thread_buf_store_desc), // SrcDesc,
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<1,
|
||||
1,
|
||||
1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder,
|
||||
3, // DstVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
|
||||
1, // DstScalarStrideInVector,
|
||||
false // DstResetCoordinateAfterRun,
|
||||
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
|
||||
thread_m_cluster_id,
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock),
|
||||
CElementwiseOperation{}};
|
||||
|
||||
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
p_semaphore[reduction_idx] = 0;
|
||||
}
|
||||
using Accumulation = ck::detail::
|
||||
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
|
||||
|
||||
for(int i_m = 0; i_m < MReduceIters; i_m++)
|
||||
{
|
||||
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
|
||||
acc_buf.Clear();
|
||||
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
|
||||
{
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global,
|
||||
AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) +
|
||||
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
|
||||
c_partial_acc_block_m_n.GetElementSpaceSize());
|
||||
|
||||
acc_load.Run(c_partial_acc_block_m_n,
|
||||
c_partial_acc_buf,
|
||||
acc_thread_buf_load_desc,
|
||||
make_tuple(I0, I0),
|
||||
parcial_acc_buf);
|
||||
|
||||
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
|
||||
[&](auto i_vec) {
|
||||
constexpr auto offset =
|
||||
acc_thread_buf_load_desc.CalculateOffset(
|
||||
make_tuple(0, i_vec));
|
||||
Accumulation::Calculate(acc_buf(Number<offset>{}),
|
||||
parcial_acc_buf[Number<offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
if(thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock <
|
||||
NPerBlock)
|
||||
{
|
||||
acc_store.Run(acc_thread_buf_store_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
acc_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
if constexpr(NReduceIters != 1)
|
||||
{
|
||||
if constexpr(i_n_reduce != (NReduceIters - 1))
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n_reverse);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n_reverse);
|
||||
}
|
||||
}
|
||||
});
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_m);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_m);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// offset for last acc buffer of this block
|
||||
uint32_t block_acc_offset =
|
||||
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
|
||||
MPerBlock * NPerBlock;
|
||||
while(true)
|
||||
{
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end - 1, tile_idx, iter_offset);
|
||||
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
problem.K,
|
||||
problem.KPadded,
|
||||
problem.StrideA,
|
||||
problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
|
||||
problem.KPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.StrideB,
|
||||
problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto block_work_idx =
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
|
||||
|
||||
@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
|
||||
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
|
||||
.GetElementSpaceSize());
|
||||
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
|
||||
.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
|
||||
// LDS to global partial acc
|
||||
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
|
||||
ThisThreadBlock, // index_t BlockSize,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
// InMemoryDataOperationEnum::Set, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave *
|
||||
NPerXdl>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CShuffleDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
|
||||
// false, otherwise has scratch
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
|
||||
// false, otherwise has scratch
|
||||
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_element_op};
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
make_tuple(0, 0, 0, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_partial_acc_buf),
|
||||
InMemoryDataOperationEnum::Set>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
c_partial_acc_buf);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
// increase the counter for this tile
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
wg_barrier.inc(tile_idx);
|
||||
}
|
||||
}
|
||||
} // shuffle c and write-out end
|
||||
|
||||
// exit condition
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
} // while loop
|
||||
|
||||
} // for loop
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
Problem& problem)
|
||||
Problem& problem,
|
||||
void* p_workspace)
|
||||
{
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(
|
||||
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
if(is_reduction_block)
|
||||
{
|
||||
// descriptors
|
||||
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
|
||||
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
|
||||
const auto reduce_thread_cluster_idx =
|
||||
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
|
||||
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
|
||||
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
|
||||
|
||||
constexpr auto MReduceIters = math::integer_divide_ceil(
|
||||
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
|
||||
constexpr auto NReduceIters = math::integer_divide_ceil(
|
||||
Number<NPerBlock>{},
|
||||
cluster_length_reduce.At(I1) *
|
||||
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
|
||||
|
||||
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
constexpr auto acc_thread_buf_store_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
|
||||
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
|
||||
|
||||
constexpr auto partial_acc_load_step_n =
|
||||
make_multi_index(0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_m =
|
||||
make_multi_index(cluster_length_reduce.At(I0), 0);
|
||||
|
||||
constexpr auto partial_acc_store_step_n =
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_m =
|
||||
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
parcial_acc_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
acc_buf;
|
||||
|
||||
// start to compute
|
||||
auto reduction_idx =
|
||||
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
|
||||
reduction_idx, problem.M, problem.N);
|
||||
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
|
||||
uint32_t tile_acc_offset_start =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
|
||||
uint32_t tile_acc_offset_end =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
|
||||
1);
|
||||
|
||||
uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
p_semaphore[reduction_idx] = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
AccDataType, // SrcData,
|
||||
AccDataType, // DstData,
|
||||
decltype(c_partial_acc_block_m_n), // SrcDesc,
|
||||
decltype(acc_thread_buf_load_desc), // DstDesc,
|
||||
Sequence<1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1>, // DimAccessOrder,
|
||||
1, // SrcVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
|
||||
1, // SrcScalarStrideInVector,
|
||||
false // SrcResetCoordinateAfterRun,
|
||||
>{c_partial_acc_block_m_n,
|
||||
make_multi_index(thread_m_cluster_id,
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock)};
|
||||
|
||||
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType, // SrcData,
|
||||
CDataType, // DstData,
|
||||
decltype(acc_thread_buf_store_desc), // SrcDesc,
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<1,
|
||||
1,
|
||||
1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder,
|
||||
3, // DstVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
|
||||
1, // DstScalarStrideInVector,
|
||||
false // DstResetCoordinateAfterRun,
|
||||
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
|
||||
thread_m_cluster_id,
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock),
|
||||
CElementwiseOperation{}};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0) {
|
||||
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
|
||||
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
|
||||
}
|
||||
#endif
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(&p_semaphore[reduction_idx], 1);
|
||||
}
|
||||
|
||||
wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
|
||||
using Accumulation = ck::detail::
|
||||
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
|
||||
|
||||
for(int i_m = 0; i_m < MReduceIters; i_m++)
|
||||
{
|
||||
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
|
||||
acc_buf.Clear();
|
||||
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
|
||||
{
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global,
|
||||
AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) +
|
||||
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
|
||||
c_partial_acc_block_m_n.GetElementSpaceSize());
|
||||
|
||||
acc_load.Run(c_partial_acc_block_m_n,
|
||||
c_partial_acc_buf,
|
||||
acc_thread_buf_load_desc,
|
||||
make_tuple(I0, I0),
|
||||
parcial_acc_buf);
|
||||
|
||||
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
|
||||
[&](auto i_vec) {
|
||||
constexpr auto offset =
|
||||
acc_thread_buf_load_desc.CalculateOffset(
|
||||
make_tuple(0, i_vec));
|
||||
Accumulation::Calculate(acc_buf(Number<offset>{}),
|
||||
parcial_acc_buf[Number<offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
if(thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock <
|
||||
NPerBlock)
|
||||
{
|
||||
acc_store.Run(acc_thread_buf_store_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
acc_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
if constexpr(NReduceIters != 1)
|
||||
{
|
||||
if constexpr(i_n_reduce != (NReduceIters - 1))
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n_reverse);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n_reverse);
|
||||
}
|
||||
}
|
||||
});
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_m);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_m);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// offset for last acc buffer of this block
|
||||
uint32_t block_acc_offset =
|
||||
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
|
||||
MPerBlock * NPerBlock;
|
||||
while(true)
|
||||
{
|
||||
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end - 1, tile_idx, iter_offset);
|
||||
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
problem.K,
|
||||
problem.KPadded,
|
||||
problem.StrideA,
|
||||
problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
|
||||
problem.KPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.StrideB,
|
||||
problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto block_work_idx =
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
|
||||
|
||||
@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
|
||||
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared_0),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
|
||||
.GetElementSpaceSize());
|
||||
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
|
||||
.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
|
||||
// LDS to global partial acc
|
||||
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
|
||||
ThisThreadBlock, // index_t BlockSize,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
// InMemoryDataOperationEnum::Set, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave *
|
||||
NPerXdl>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CShuffleDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
|
||||
// false, otherwise has scratch
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
|
||||
// false, otherwise has scratch
|
||||
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
make_tuple(0, 0, 0, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_partial_acc_buf),
|
||||
InMemoryDataOperationEnum::Set>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
c_partial_acc_buf);
|
||||
}
|
||||
}
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
});
|
||||
}
|
||||
// exit condition
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
// increase the counter for this tile
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
wg_barrier.inc(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user