Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-17 11:30:02 +00:00)

universal streamk fp8 changes (#1665)

* universal streamk fp8 changes & ckprofiler instances
* revert strides to -1 and verification options
* fp8 exclusion on pre-gfx94 for universal_streamk
* PR review based revisions: permissions reverted, removed hip err checks

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: d6d4c2788b]
Commit 5a5bfe14f4 (parent 326639e80c), committed via GitHub.

include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp — 382 lines changed, Normal file → Executable file
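Before the hunks below, a minimal standalone C++ sketch (not CK code; every name in it is illustrative) of the compile-time dispatch between the two Stream-K write-out strategies this change wires up: Atomic, which zero-fills C and accumulates partial tiles with atomic adds, and Reduction, which stages partial tiles in a workspace and lets dedicated reduction workgroups combine them.

// Standalone sketch (not CK code): compile-time dispatch on the Stream-K
// reduction strategy, mirroring the `if constexpr` pattern added in Invoker::Run.
#include <cstdio>

enum class StreamKReductionStrategy { Atomic, Reduction };

template <StreamKReductionStrategy Strategy>
void run_streamk_gemm()
{
    if constexpr(Strategy == StreamKReductionStrategy::Atomic)
    {
        // Atomic: C must be zero-initialized first, because partial tiles are
        // accumulated directly into C with atomic adds.
        std::printf("memset C to 0, then launch kernel with AtomicAdd writes\n");
    }
    else if constexpr(Strategy == StreamKReductionStrategy::Reduction)
    {
        // Reduction: partial accumulators go to a workspace buffer; a semaphore
        // region after it is cleared in a preprocess step, and dedicated
        // reduction workgroups combine the partial tiles into C.
        std::printf("memset semaphore region, launch kernel, reduce partials into C\n");
    }
}

int main()
{
    run_streamk_gemm<StreamKReductionStrategy::Atomic>();
    run_streamk_gemm<StreamKReductionStrategy::Reduction>();
}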
@@ -131,6 +131,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
@@ -147,26 +148,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
            hipGetErrorString(hipMemsetAsync(
                arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));

            if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
                         StreamKReductionStrategy::Atomic)
            {
                hip_check_error(hipMemsetAsync(
                    arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
            }

            const auto Run = [&](const auto& kernel) {
                dim3 grid_dim;
                if(arg.Grid_size < 0)
                {
                    int occupancy, num_cu;
                    hipError_t rtn;
                    rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
                        &occupancy, kernel, BlockSize, 0);
                    hip_check_error(rtn);
                    hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                        &occupancy, kernel, BlockSize, 0));
                    hipDeviceProp_t dev_prop;
                    hipDevice_t dev;
                    rtn = hipGetDevice(&dev);
                    hip_check_error(rtn);
                    rtn = hipGetDeviceProperties(&dev_prop, dev);
                    hip_check_error(rtn);
                    num_cu = dev_prop.multiProcessorCount;
                    hip_check_error(hipGetDevice(&dev));
                    hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
                    num_cu = dev_prop.multiProcessorCount;
                    arg.Grid_size = num_cu * occupancy;
                    grid_dim = arg.Grid_size;
                }
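The hunk above picks the grid size from occupancy and the CU count whenever arg.Grid_size < 0. A minimal standalone HIP sketch of that sizing rule (illustrative only: dummy_kernel and the 256-thread block size are assumptions, and the real code wraps every HIP call in its hip_check_error helper and queries the selected GEMM kernel instead):

// Standalone HIP sketch: grid size = CUs * max resident blocks per CU.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    constexpr int block_size = 256;

    int occupancy = 0;
    // Max resident blocks per compute unit for this kernel / block size.
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, dummy_kernel, block_size, 0);

    int dev = 0;
    hipGetDevice(&dev);
    hipDeviceProp_t dev_prop;
    hipGetDeviceProperties(&dev_prop, dev);

    // Just enough workgroups to fill the GPU once, as in the diff above.
    const int grid_size = dev_prop.multiProcessorCount * occupancy;
    std::printf("grid_size = %d (CUs = %d, occupancy = %d)\n",
                grid_size, dev_prop.multiProcessorCount, occupancy);
    return 0;
}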
@@ -196,8 +198,31 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                else
                {
                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
                    if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
                                 StreamKReductionStrategy::Atomic)
                    {
                        ave_time = launch_and_time_kernel(
                            stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
                    }
                    else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
                                      StreamKReductionStrategy::Reduction)
                    {
                        char* workspace_semaphore =
                            reinterpret_cast<char*>(arg.p_workspace_) +
                            arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
                                sizeof(GemmAccDataType));
                        auto preprocess = [&]() {
                            hipMemsetAsync(
                                workspace_semaphore,
                                0,
                                // sizeof(uint32_t),
                                arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
                                stream_config.stream_id_);
                        };

                        ave_time = launch_and_time_kernel_with_preprocess(
                            stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
                    }
                }
            };
@@ -211,14 +236,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                    true,
                                                                    InMemoryDataOperationEnum::Set,
                                                                    minimum_occupancy>;

                    Run(kernel);
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
@@ -340,53 +363,49 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -396,14 +415,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        false,
                                                        InMemoryDataOperationEnum::Set,
                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                    false,
                                                                    InMemoryDataOperationEnum::Set,
                                                                    minimum_occupancy>;
                    Run(kernel);
                }
            }
@@ -418,6 +434,29 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
        }
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
        if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
                     StreamKReductionStrategy::Reduction)
        {
            return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
        }
        else
        {
            return 0;
        }
    }

    void SetWorkSpacePointer(BaseArgument* pArg,
                             void* p_workspace,
                             const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;
    }

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
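GetWorkSpaceSize above returns the tile map's workspace size only for the Reduction strategy, and the Run hunk earlier derives the semaphore pointer as workspace base plus get_workspace_size_for_acc(...). A standalone sketch of the layout that implies (partial-accumulator tiles first, then one uint32_t counter per reduced tile); all counts and sizes below are hypothetical, the real ones come from BlockToCTileMap_GemmStreamK_v2:

// Standalone sketch: Reduction-strategy workspace layout = [partial acc tiles][semaphores].
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical example values.
    const std::size_t num_partial_acc_tiles = 16;     // tiles that need a partial buffer
    const std::size_t m_per_block = 256, n_per_block = 128;
    const std::size_t acc_elem_size = sizeof(float);  // GemmAccDataType assumed float
    const std::size_t num_reduction_tiles = 8;        // one counter per reduced tile

    const std::size_t acc_bytes =
        num_partial_acc_tiles * m_per_block * n_per_block * acc_elem_size;
    const std::size_t sem_bytes = num_reduction_tiles * sizeof(std::uint32_t);

    // Total handed back by GetWorkSpaceSize-style logic; the semaphore region
    // starts right after the accumulation region (base + acc_bytes).
    std::printf("workspace = %zu acc bytes + %zu semaphore bytes\n", acc_bytes, sem_bytes);
    return 0;
}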
@@ -464,8 +503,205 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
CElementwiseOperation)
|
||||
{
|
||||
|
||||
return Argument{
|
||||
p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
int occupancy, num_cu;
|
||||
const auto calculate_grid_size = [&](const auto& kernel) {
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
Grid_size = num_cu * occupancy;
|
||||
};
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
calculate_grid_size(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp — 816 lines changed, Normal file → Executable file
@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"

namespace ck {
@@ -38,7 +40,7 @@ __global__ void
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -62,7 +64,13 @@ __global__ void
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
        karg.p_a_grid,
        karg.p_b_grid,
        karg.p_c_grid,
        p_shared_0,
        p_shared_1,
        karg,
        karg.p_workspace_);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
          p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_c_grid{p_c_grid_}
          p_c_grid{p_c_grid_},
          block_2_ctile_map_streamk(
              M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
    {
    }
@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
        BlockToCTileMap_GemmStreamK_v2<MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       StreamKReductionStrategy::Atomic,
                                       8,
                                       4>
            block_2_ctile_map_streamk;
    };

    struct SplitKBatchOffset
@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
                       Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                       Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
                       Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
    }

    using BlockwiseGemmPipe =
        remove_cvref_t<decltype(BlockGemmPipeline_Selector<
                                BlkGemmPipelineVer,
@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    __host__ __device__ static constexpr auto GetClusterLengthReduction()
    {
        // TODO: assume C is row major
        // TODO: we always first loop over N, then M
        constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
        constexpr auto NPerBlockReduction =
            NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
        constexpr auto MPerBlockReduction =
            (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
        return Sequence<MPerBlockReduction, NPerBlockReduction>{};
    }

    __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
    {
        const auto c_partial_acc_block_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
                                                    make_tuple(NPerBlock, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
                                                    make_tuple(I1, MPerBlock));
            }
        }();
        return c_partial_acc_block_m_n;
    }
    using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
                                                                  NPerBlock,
                                                                  KPerBlock,
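GetClusterLengthReduction above arranges each reduction workgroup as MPerBlockReduction x NPerBlockReduction threads, each thread owning a CShuffleBlockTransferScalarPerVector_NPerBlock-wide vector along N. A standalone sketch of that arithmetic with hypothetical tile parameters (NPerBlock = 128, BlockSize = 256, 8-wide vectors; next_power_of_two re-implemented locally instead of CK's math::next_power_of_two), which yields a Sequence<16, 16> cluster:

// Standalone sketch of the cluster-length arithmetic used by the reduction blocks.
#include <cstdio>

constexpr int next_power_of_two(int x)
{
    int p = 1;
    while(p < x)
        p *= 2;
    return p;
}

int main()
{
    constexpr int NPerBlock = 128, BlockSize = 256, ScalarPerVector = 8;

    constexpr int NPerBlockPow2      = next_power_of_two(NPerBlock);       // 128
    constexpr int NPerBlockReduction = NPerBlockPow2 / ScalarPerVector;    // 16 threads along N
    constexpr int MPerBlockReduction =
        (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;         // 16 threads along M

    // Equivalent of Sequence<MPerBlockReduction, NPerBlockReduction>{} in the diff.
    std::printf("cluster = <%d, %d>\n", MPerBlockReduction, NPerBlockReduction);
    return 0;
}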
@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
Problem& problem)
|
||||
Problem& problem,
|
||||
void* p_workspace)
|
||||
{
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
if(is_reduction_block)
|
||||
{
|
||||
// descriptors
|
||||
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
|
||||
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
|
||||
const auto reduce_thread_cluster_idx =
|
||||
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
|
||||
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
|
||||
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
|
||||
|
||||
constexpr auto MReduceIters = math::integer_divide_ceil(
|
||||
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
|
||||
constexpr auto NReduceIters = math::integer_divide_ceil(
|
||||
Number<NPerBlock>{},
|
||||
cluster_length_reduce.At(I1) *
|
||||
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
|
||||
|
||||
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
constexpr auto acc_thread_buf_store_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
|
||||
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
|
||||
|
||||
constexpr auto partial_acc_load_step_n =
|
||||
make_multi_index(0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_m =
|
||||
make_multi_index(cluster_length_reduce.At(I0), 0);
|
||||
|
||||
constexpr auto partial_acc_store_step_n =
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_m =
|
||||
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
parcial_acc_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
acc_buf;
|
||||
|
||||
// start to compute
|
||||
auto reduction_idx =
|
||||
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
|
||||
reduction_idx, problem.M, problem.N);
|
||||
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
|
||||
uint32_t tile_acc_offset_start =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
|
||||
uint32_t tile_acc_offset_end =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
|
||||
1);
|
||||
__syncthreads();
|
||||
|
||||
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
AccDataType, // SrcData,
|
||||
AccDataType, // DstData,
|
||||
decltype(c_partial_acc_block_m_n), // SrcDesc,
|
||||
decltype(acc_thread_buf_load_desc), // DstDesc,
|
||||
Sequence<1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1>, // DimAccessOrder,
|
||||
1, // SrcVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
|
||||
1, // SrcScalarStrideInVector,
|
||||
false // SrcResetCoordinateAfterRun,
|
||||
>{c_partial_acc_block_m_n,
|
||||
make_multi_index(thread_m_cluster_id,
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock)};
|
||||
|
||||
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType, // SrcData,
|
||||
CDataType, // DstData,
|
||||
decltype(acc_thread_buf_store_desc), // SrcDesc,
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<1,
|
||||
1,
|
||||
1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder,
|
||||
3, // DstVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
|
||||
1, // DstScalarStrideInVector,
|
||||
false // DstResetCoordinateAfterRun,
|
||||
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
|
||||
thread_m_cluster_id,
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock),
|
||||
CElementwiseOperation{}};
|
||||
|
||||
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
p_semaphore[reduction_idx] = 0;
|
||||
}
|
||||
using Accumulation = ck::detail::
|
||||
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
|
||||
|
||||
for(int i_m = 0; i_m < MReduceIters; i_m++)
|
||||
{
|
||||
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
|
||||
acc_buf.Clear();
|
||||
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
|
||||
{
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global,
|
||||
AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) +
|
||||
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
|
||||
c_partial_acc_block_m_n.GetElementSpaceSize());
|
||||
|
||||
acc_load.Run(c_partial_acc_block_m_n,
|
||||
c_partial_acc_buf,
|
||||
acc_thread_buf_load_desc,
|
||||
make_tuple(I0, I0),
|
||||
parcial_acc_buf);
|
||||
|
||||
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
|
||||
[&](auto i_vec) {
|
||||
constexpr auto offset =
|
||||
acc_thread_buf_load_desc.CalculateOffset(
|
||||
make_tuple(0, i_vec));
|
||||
Accumulation::Calculate(acc_buf(Number<offset>{}),
|
||||
parcial_acc_buf[Number<offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
if(thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock <
|
||||
NPerBlock)
|
||||
{
|
||||
acc_store.Run(acc_thread_buf_store_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
acc_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
if constexpr(NReduceIters != 1)
|
||||
{
|
||||
if constexpr(i_n_reduce != (NReduceIters - 1))
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n_reverse);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n_reverse);
|
||||
}
|
||||
}
|
||||
});
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_m);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_m);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// offset for last acc buffer of this block
|
||||
uint32_t block_acc_offset =
|
||||
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
|
||||
MPerBlock * NPerBlock;
|
||||
while(true)
|
||||
{
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end - 1, tile_idx, iter_offset);
|
||||
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
problem.K,
|
||||
problem.KPadded,
|
||||
problem.StrideA,
|
||||
problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
|
||||
problem.KPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.StrideB,
|
||||
problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto block_work_idx =
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
|
||||
|
||||
@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
|
||||
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
|
||||
.GetElementSpaceSize());
|
||||
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
|
||||
.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
|
||||
// LDS to global partial acc
|
||||
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
|
||||
ThisThreadBlock, // index_t BlockSize,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
// InMemoryDataOperationEnum::Set, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave *
|
||||
NPerXdl>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CShuffleDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
                    false, // bool ThreadTransferSrcResetCoordinateAfterRun => needs to be
                           // false, otherwise it has scratch
                    false> // bool ThreadTransferDstResetCoordinateAfterRun => needs to be
                           // false, otherwise it has scratch
|
||||
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_element_op};
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
make_tuple(0, 0, 0, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_partial_acc_buf),
|
||||
InMemoryDataOperationEnum::Set>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
c_partial_acc_buf);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
// increase the counter for this tile
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
wg_barrier.inc(tile_idx);
|
||||
}
|
||||
}
|
||||
} // shuffle c and write-out end
|
||||
|
||||
// exit condition
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
} // while loop
|
||||
|
||||
} // for loop
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
Problem& problem)
|
||||
Problem& problem,
|
||||
void* p_workspace)
|
||||
{
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(
|
||||
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
if(is_reduction_block)
|
||||
{
|
||||
// descriptors
|
||||
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
|
||||
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
|
||||
const auto reduce_thread_cluster_idx =
|
||||
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
|
||||
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
|
||||
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
|
||||
|
||||
constexpr auto MReduceIters = math::integer_divide_ceil(
|
||||
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
|
||||
constexpr auto NReduceIters = math::integer_divide_ceil(
|
||||
Number<NPerBlock>{},
|
||||
cluster_length_reduce.At(I1) *
|
||||
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
|
||||
|
||||
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
constexpr auto acc_thread_buf_store_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
|
||||
|
||||
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
|
||||
|
||||
constexpr auto partial_acc_load_step_n =
|
||||
make_multi_index(0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_load_step_m =
|
||||
make_multi_index(cluster_length_reduce.At(I0), 0);
|
||||
|
||||
constexpr auto partial_acc_store_step_n =
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
cluster_length_reduce.At(I1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock);
|
||||
constexpr auto partial_acc_store_step_m =
|
||||
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
parcial_acc_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>
|
||||
acc_buf;
|
||||
|
||||
// start to compute
|
||||
auto reduction_idx =
|
||||
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
|
||||
reduction_idx, problem.M, problem.N);
|
||||
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
|
||||
uint32_t tile_acc_offset_start =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
|
||||
uint32_t tile_acc_offset_end =
|
||||
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
|
||||
1);
|
||||
|
||||
uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
p_semaphore[reduction_idx] = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
AccDataType, // SrcData,
|
||||
AccDataType, // DstData,
|
||||
decltype(c_partial_acc_block_m_n), // SrcDesc,
|
||||
decltype(acc_thread_buf_load_desc), // DstDesc,
|
||||
Sequence<1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1>, // DimAccessOrder,
|
||||
1, // SrcVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
|
||||
1, // SrcScalarStrideInVector,
|
||||
false // SrcResetCoordinateAfterRun,
|
||||
>{c_partial_acc_block_m_n,
|
||||
make_multi_index(thread_m_cluster_id,
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock)};
|
||||
|
||||
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType, // SrcData,
|
||||
CDataType, // DstData,
|
||||
decltype(acc_thread_buf_store_desc), // SrcDesc,
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<1,
|
||||
1,
|
||||
1,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder,
|
||||
3, // DstVectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
|
||||
1, // DstScalarStrideInVector,
|
||||
false // DstResetCoordinateAfterRun,
|
||||
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
|
||||
thread_m_cluster_id,
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
|
||||
thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock),
|
||||
CElementwiseOperation{}};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0) {
|
||||
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
|
||||
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
|
||||
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
|
||||
}
|
||||
#endif
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(&p_semaphore[reduction_idx], 1);
|
||||
}
|
||||
|
||||
wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
|
||||
using Accumulation = ck::detail::
|
||||
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
|
||||
|
||||
for(int i_m = 0; i_m < MReduceIters; i_m++)
|
||||
{
|
||||
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
|
||||
acc_buf.Clear();
|
||||
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
|
||||
{
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global,
|
||||
AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) +
|
||||
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
|
||||
c_partial_acc_block_m_n.GetElementSpaceSize());
|
||||
|
||||
acc_load.Run(c_partial_acc_block_m_n,
|
||||
c_partial_acc_buf,
|
||||
acc_thread_buf_load_desc,
|
||||
make_tuple(I0, I0),
|
||||
parcial_acc_buf);
|
||||
|
||||
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
|
||||
[&](auto i_vec) {
|
||||
constexpr auto offset =
|
||||
acc_thread_buf_load_desc.CalculateOffset(
|
||||
make_tuple(0, i_vec));
|
||||
Accumulation::Calculate(acc_buf(Number<offset>{}),
|
||||
parcial_acc_buf[Number<offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
if(thread_n_cluster_id *
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock <
|
||||
NPerBlock)
|
||||
{
|
||||
acc_store.Run(acc_thread_buf_store_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
acc_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
if constexpr(NReduceIters != 1)
|
||||
{
|
||||
if constexpr(i_n_reduce != (NReduceIters - 1))
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_n_reverse);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_n_reverse);
|
||||
}
|
||||
}
|
||||
});
|
||||
{
|
||||
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
|
||||
partial_acc_load_step_m);
|
||||
acc_store.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
partial_acc_store_step_m);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// offset for last acc buffer of this block
|
||||
uint32_t block_acc_offset =
|
||||
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
|
||||
MPerBlock * NPerBlock;
|
||||
while(true)
|
||||
{
|
||||
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end - 1, tile_idx, iter_offset);
|
||||
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
problem.K,
|
||||
problem.KPadded,
|
||||
problem.StrideA,
|
||||
problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
|
||||
problem.KPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.StrideB,
|
||||
problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto block_work_idx =
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
|
||||
|
||||
@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
|
||||
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared_0),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
|
||||
.GetElementSpaceSize());
|
||||
|
||||
auto c_partial_acc_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
|
||||
.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
|
||||
// LDS to global partial acc
|
||||
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
|
||||
ThisThreadBlock, // index_t BlockSize,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
// InMemoryDataOperationEnum::Set, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave *
|
||||
NPerXdl>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CShuffleDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
                    false, // bool ThreadTransferSrcResetCoordinateAfterRun => needs to be
                           // false, otherwise it has scratch
                    false> // bool ThreadTransferDstResetCoordinateAfterRun => needs to be
                           // false, otherwise it has scratch
|
||||
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_grid_buf),
|
||||
InMemoryDataOperationEnum::AtomicAdd>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
make_tuple(0, 0, 0, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
|
||||
|
||||
c_block_copy_lds_to_partial_acc
|
||||
.template Run<decltype(c_shuffle_block_buf),
|
||||
decltype(c_partial_acc_buf),
|
||||
InMemoryDataOperationEnum::Set>(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
|
||||
c_partial_acc_buf);
|
||||
}
|
||||
}
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
});
|
||||
}
|
||||
// exit condition
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
// increase the counter for this tile
|
||||
workgroup_barrier wg_barrier(p_semaphore);
|
||||
wg_barrier.inc(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||