mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
initial stream-k implementation with example (#699)
* initial stream-k implementation with example
* fix unexpected change in err
* improve a little bit performance by reorganize pipeline.
* improve perf a little bit by swizzle block idx
* add profiler
* update example
* fix spelling
* shrink karg for streamk
* support dynamic buffer using memory coherence glc_slc bit from template
* control memory coherence while construct dynamic buffer
* update reduction for streamk(not ready yet)
* Add template parameter to make_dynamic_buffer to support amd_buffer coherence setting
* fix build issue
* fix several bug
* now result is correct, everything works (but has scratch)
* remove scratch by manually reset coordinate
* update device code
* fix a bug in final reduce
* fix something in example
* update async memset
* fix enum as camel case
* modify coherence enum name
* clean code and use atomic streamk by default
* remove unused var
* throw exception if have empty pointer
* fix format
* fix CI warning
* fix type in init
* modify CI error
* filter out on gfx10+
* restore changed example code
---------
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
[ROCm/composable_kernel commit: e7dca79d27]
This commit is contained in:
@@ -50,6 +50,8 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
|
||||
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
|
||||
endif()
|
||||
|
||||
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
|
||||
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_f8)
|
||||
|
||||
@@ -33,6 +33,19 @@ struct ProblemSize final
|
||||
ck::index_t StrideC = 4096;
|
||||
};
|
||||
|
||||
struct ProblemSizeStreamK final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
inline bool
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
template <typename ProblemType>
|
||||
bool parse_cmd_args(int, char*[], ProblemType&, ExecutionConfig&)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSize>(int argc,
|
||||
char* argv[],
|
||||
ProblemSize& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeStreamK& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc >= 10)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideC = std::stoi(argv[9]);
|
||||
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.NumSKBlocks = std::stoi(argv[10]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: NumSKBlocks(optional)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
49
example/01_gemm/gemm_xdl_streamk.cpp
Normal file
49
example/01_gemm/gemm_xdl_streamk.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
// using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
|
||||
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
|
||||
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
|
||||
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
|
||||
|
||||
|
||||
// // clang-format on
|
||||
// clang-format on
|
||||
|
||||
using DeviceGemmInstance = DeviceGemmStreamK;
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); }
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
@@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -25,12 +33,37 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 0:
|
||||
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
|
||||
break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
@@ -66,42 +99,114 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
#endif
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
if constexpr(std::is_same<ProblemType, ProblemSize>::value &&
|
||||
!std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
return true;
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
}
|
||||
else if constexpr(std::is_same<ProblemType, ProblemSizeStreamK>::value &&
|
||||
std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
|
||||
{
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
problem_size.NumSKBlocks);
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
if(workspace_size != 0)
|
||||
{
|
||||
workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
#if 0
|
||||
// TODO!!!!!
|
||||
if(workspace_size != 0){
|
||||
float * ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
|
||||
size_t ws_dwords = workspace_size / sizeof(float);
|
||||
workspace.FromDevice(ws_ptr);
|
||||
|
||||
for(size_t i = 0; i < ws_dwords; i++) {
|
||||
uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
|
||||
printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
|
||||
}
|
||||
free(ws_ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -149,3 +254,11 @@ bool run_gemm_example(int argc, char* argv[])
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
bool run_gemm_streamk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeStreamK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
@@ -73,3 +73,72 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename... Args, typename F, typename PreProcessFunc>
|
||||
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up 1 time\n");
|
||||
#endif
|
||||
// warm up
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
|
||||
const int nrepeat = 10;
|
||||
#if DEBUG_LOG
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
#endif
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
}
|
||||
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
|
||||
float total_time = 0;
|
||||
|
||||
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorDim,
|
||||
index_t ScalarPerVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct ThreadGroupTensorSliceTransfer_v6r1r2
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
dst_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
|
||||
dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.template Run<SrcBuffer, DstBuffer, DstInMemOp>(
|
||||
src_desc, src_buf, dst_desc, dst_buf);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
|
||||
dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
ElementwiseOperation,
|
||||
decltype(thread_slice_lengths),
|
||||
DimAccessOrder,
|
||||
VectorDim,
|
||||
ScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmStreamK : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t NumSKBlocks = 0) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceGemmStreamKPtr = std::unique_ptr<DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,357 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
ck::index_t ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
ck::index_t BBlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
|
||||
struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk<
|
||||
BlockSize,
|
||||
BlockToCTileMap_GemmStreamK<MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock * K1,
|
||||
StreamKReductionStrategy::Atomic>,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
void Print(const Argument& karg) { karg.Print(); }
|
||||
|
||||
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
Print(karg);
|
||||
}
|
||||
if(!GridwiseGemm::CheckValidity(karg))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
|
||||
"setting");
|
||||
}
|
||||
|
||||
dim3 grid_dims = karg.block_mapping.get_grid_dims();
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
|
||||
|
||||
// TODO: remove clear buffer for streamk kernels
|
||||
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
grid_dims,
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
karg.p_a_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_workspace_,
|
||||
karg.M,
|
||||
karg.N,
|
||||
karg.K,
|
||||
karg.StrideA,
|
||||
karg.StrideB,
|
||||
karg.StrideC,
|
||||
karg.block_mapping);
|
||||
}
|
||||
else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
|
||||
karg.block_mapping.get_workspace_size_for_acc(
|
||||
sizeof(typename GridwiseGemm::FloatAcc));
|
||||
auto preprocess = [&]() {
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(workspace_semaphore,
|
||||
0,
|
||||
karg.block_mapping.get_workspace_size_for_semaphore(),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = launch_and_time_kernel_with_preprocess(stream_config,
|
||||
preprocess,
|
||||
kernel,
|
||||
grid_dims,
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
karg.p_a_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_workspace_,
|
||||
karg.M,
|
||||
karg.N,
|
||||
karg.K,
|
||||
karg.StrideA,
|
||||
karg.StrideB,
|
||||
karg.StrideC,
|
||||
karg.block_mapping);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
|
||||
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc));
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
pArg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return GridwiseGemm::CheckValidity(karg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
uint32_t NumSKBlocks = 0xffffffff)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
|
||||
int occupancy, num_cu;
|
||||
hipError_t rtn;
|
||||
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
|
||||
hip_check_error(rtn);
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
rtn = hipGetDevice(&dev);
|
||||
hip_check_error(rtn);
|
||||
rtn = hipGetDeviceProperties(&dev_prop, dev);
|
||||
hip_check_error(rtn);
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
static_cast<uint32_t>(num_cu),
|
||||
static_cast<uint32_t>(occupancy),
|
||||
NumSKBlocks};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
index_t NumSKBlocks = 0) override
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
|
||||
int occupancy, num_cu;
|
||||
hipError_t rtn;
|
||||
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
|
||||
hip_check_error(rtn);
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
rtn = hipGetDevice(&dev);
|
||||
hip_check_error(rtn);
|
||||
rtn = hipGetDeviceProperties(&dev_prop, dev);
|
||||
hip_check_error(rtn);
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
|
||||
reinterpret_cast<const BDataType*>(p_b),
|
||||
reinterpret_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
static_cast<uint32_t>(num_cu),
|
||||
static_cast<uint32_t>(occupancy),
|
||||
static_cast<uint32_t>(NumSKBlocks));
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include <limits>
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -669,4 +671,406 @@ struct BlockToCTileMap_3DGrid_KSplit
|
||||
}
|
||||
};
|
||||
|
||||
enum StreamKReductionStrategy
|
||||
{
|
||||
Atomic = 0, // sk block use atomic to do reduction
|
||||
Reduction, // let some workgroup responsible for doing the reduction operation
|
||||
};
|
||||
|
||||
template <uint32_t MPerBlock_,
|
||||
uint32_t NPerBlock_,
|
||||
uint32_t KPerBlock_,
|
||||
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
|
||||
uint32_t TileSwizzleSubM_ = 8>
|
||||
struct BlockToCTileMap_GemmStreamK
|
||||
{
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
static constexpr uint32_t MPerBlock = MPerBlock_;
|
||||
static constexpr uint32_t NPerBlock = NPerBlock_;
|
||||
static constexpr uint32_t KPerBlock = KPerBlock_;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
|
||||
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
|
||||
|
||||
//--------------------------------------
|
||||
// pass to device
|
||||
uint32_t sk_num_blocks;
|
||||
uint32_t sk_num_big_blocks;
|
||||
uint32_t dp_start_block_idx;
|
||||
uint32_t reduction_start_block_idx;
|
||||
uint32_t k_iters_per_big_block;
|
||||
MDiv2 n_tiles;
|
||||
MDiv k_iters_per_tile;
|
||||
MDiv eqav_tiles_big; // for reduction
|
||||
MDiv eqav_tiles_little; // for reduction
|
||||
|
||||
// MDiv tile_swizzle_sub_m_rem;
|
||||
//--------------------------------------
|
||||
|
||||
// prefer construct on host
|
||||
BlockToCTileMap_GemmStreamK(uint32_t m,
|
||||
uint32_t n,
|
||||
uint32_t k,
|
||||
uint32_t num_cu,
|
||||
uint32_t occupancy,
|
||||
uint32_t sk_blocks = 0xffffffff)
|
||||
{
|
||||
uint32_t num_tiles =
|
||||
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
|
||||
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
|
||||
|
||||
// one cu can hold one wg at one time, from the whole chip's point of view
|
||||
// if number of wg is same as num_cu, we call it 1 dispatch
|
||||
// if number of wg is 2x num_cu, we call it 2 dispatches.
|
||||
// one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
|
||||
// dispatch)
|
||||
//
|
||||
uint32_t full_dispatches = num_tiles / num_cu;
|
||||
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
|
||||
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
|
||||
|
||||
uint32_t sk_occupancy = occupancy;
|
||||
uint32_t dp_tiles = full_dispatch_tiles;
|
||||
uint32_t sk_tiles = partial_dispatche_tiles;
|
||||
|
||||
if(full_dispatches < occupancy)
|
||||
{
|
||||
// in this case, we allocate all blocks as sk blocks
|
||||
// sk_occupancy = occupancy - full_dispatches;
|
||||
sk_occupancy = 1; // TODO: single occ seems better
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatche_tiles;
|
||||
}
|
||||
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
|
||||
{
|
||||
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
|
||||
// occupancy = 3, full_dispatches = 5, 8, 11 ...
|
||||
// occupancy = 4, full_dispatches = 7, 11 ...
|
||||
sk_occupancy = 1; // left 1 slot for sk occupancy
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatche_tiles;
|
||||
}
|
||||
else
|
||||
{
|
||||
// others, we reduce 1 dispatch from dp, together with partial dispatch,
|
||||
// to construct sk dispatch
|
||||
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
|
||||
dp_tiles = full_dispatch_tiles - num_cu;
|
||||
sk_tiles = partial_dispatche_tiles + num_cu;
|
||||
}
|
||||
|
||||
// uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
|
||||
uint32_t dp_num_blocks = 0;
|
||||
|
||||
{
|
||||
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
|
||||
uint32_t max_sk_tiles =
|
||||
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
|
||||
: math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
|
||||
|
||||
// if use dp for sk-block, how many iters do we need
|
||||
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
|
||||
|
||||
uint32_t best_sk_score =
|
||||
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
|
||||
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
|
||||
tentative_sk_blocks++)
|
||||
{
|
||||
uint32_t tentative_sk_iters_per_block =
|
||||
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
|
||||
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
|
||||
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
|
||||
|
||||
// TODO: carefully adjust this parameter
|
||||
// the more sk_blocks_per_tile, the worse the overhead
|
||||
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
|
||||
if(tentative_sk_blocks % sk_tiles != 0)
|
||||
{
|
||||
// penalty for uneven divide
|
||||
cross_sk_blocks_overhead +=
|
||||
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
|
||||
}
|
||||
|
||||
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
|
||||
|
||||
if(tentative_sk_score < best_sk_score)
|
||||
{
|
||||
best_sk_score = tentative_sk_score;
|
||||
sk_num_blocks = tentative_sk_blocks;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_sk_score >= dp_for_sk_iters)
|
||||
{
|
||||
sk_num_blocks = 0;
|
||||
}
|
||||
|
||||
// give a chance to control num of sk blocks
|
||||
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
|
||||
|
||||
if(sk_num_blocks == 0)
|
||||
{
|
||||
sk_num_big_blocks = 0;
|
||||
k_iters_per_big_block = 0;
|
||||
|
||||
dp_num_blocks = num_tiles; // all tile to be dp block
|
||||
dp_start_block_idx = 0;
|
||||
sk_total_iters = 0; // clear this tiles
|
||||
}
|
||||
else
|
||||
{
|
||||
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
|
||||
// we need to decide how many iters for each sk block
|
||||
// let m = k_iters_per_sk_block
|
||||
// some of the sk block (little) will cover m iters, some (big) will cover m+1
|
||||
// we have
|
||||
// 1) l + b = sk_blocks
|
||||
// 2) l * m + b * (m + 1) = sk_total_iters
|
||||
// => (l + b) * m + b = sk_total_iters
|
||||
// => sk_blocks * m + b = sk_total_iters
|
||||
// => b = sk_total_iters - m * sk_blocks
|
||||
// NOTE: big could be zero
|
||||
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
|
||||
}
|
||||
}
|
||||
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
|
||||
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
|
||||
eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
|
||||
eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
|
||||
"sk_num_blocks:%d, "
|
||||
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
|
||||
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
|
||||
"sk_tiles:%u, workspace(acc float):%u\n",
|
||||
num_cu,
|
||||
occupancy,
|
||||
get_grid_dims().x,
|
||||
num_tiles,
|
||||
dp_tiles,
|
||||
sk_num_big_blocks,
|
||||
sk_num_blocks,
|
||||
sk_total_iters,
|
||||
dp_start_block_idx,
|
||||
dp_iters_per_block,
|
||||
dp_num_blocks,
|
||||
k_iters_per_tile.get(),
|
||||
k_iters_per_big_block,
|
||||
reduction_start_block_idx,
|
||||
get_sk_tiles(),
|
||||
get_workspace_size(sizeof(float)));
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_sk_total_iters() const
|
||||
{
|
||||
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
|
||||
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
return sk_total_iters;
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_sk_tiles() const
|
||||
{
|
||||
// tiles for sk
|
||||
uint32_t sk_total_iters = get_sk_total_iters();
|
||||
return k_iters_per_tile.div(sk_total_iters);
|
||||
}
|
||||
|
||||
__host__ __device__ dim3 get_grid_dims() const
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
|
||||
}
|
||||
else
|
||||
return dim3(reduction_start_block_idx, 1, 1);
|
||||
}
|
||||
|
||||
__device__ uint32_t get_block_idx() const
|
||||
{
|
||||
// TODO: swizzle block index for better locality
|
||||
return __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
}
|
||||
|
||||
__device__ void
|
||||
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
|
||||
{
|
||||
if(block_idx < sk_num_big_blocks)
|
||||
{
|
||||
iter_start = block_idx * k_iters_per_big_block;
|
||||
iter_end = iter_start + k_iters_per_big_block;
|
||||
}
|
||||
else if(block_idx < sk_num_blocks)
|
||||
{
|
||||
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
|
||||
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
iter_end = iter_start + (k_iters_per_big_block - 1);
|
||||
}
|
||||
else if(block_idx >= dp_start_block_idx)
|
||||
{
|
||||
uint32_t sk_total_iters = get_sk_total_iters();
|
||||
uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
|
||||
iter_end = iter_start + dp_iters_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
|
||||
uint32_t iter_end,
|
||||
uint32_t total_iter_length) const
|
||||
{
|
||||
uint32_t iter_length_mod, iter_length_quo /*unused*/;
|
||||
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
|
||||
uint32_t current_iter_length = math::min(
|
||||
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
|
||||
return current_iter_length;
|
||||
}
|
||||
|
||||
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
|
||||
|
||||
__device__ void
|
||||
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
|
||||
{
|
||||
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
|
||||
}
|
||||
|
||||
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
|
||||
{
|
||||
uint32_t m_tile_idx, n_tile_idx;
|
||||
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
|
||||
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
|
||||
|
||||
// swizzle tile
|
||||
uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
|
||||
|
||||
uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
|
||||
|
||||
const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
|
||||
? tile_swizzle_sub_m
|
||||
: tile_swizzle_sub_m_rem;
|
||||
|
||||
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
|
||||
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
|
||||
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
|
||||
|
||||
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
|
||||
|
||||
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
|
||||
|
||||
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
|
||||
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
|
||||
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
|
||||
n_tile_idx_with_adapt);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
|
||||
{
|
||||
static constexpr uint32_t alignment = 128;
|
||||
uint32_t acc_buffer_bytes =
|
||||
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
|
||||
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
|
||||
{
|
||||
return get_sk_tiles() * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
|
||||
{
|
||||
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
|
||||
const MDiv& eqav_tiles_) const
|
||||
{
|
||||
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
|
||||
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
|
||||
uint32_t quo_, rem_;
|
||||
eqav_tiles_.divmod(tile_idx_, quo_, rem_);
|
||||
return quo_ * max_eqav_tiles_ + rem_;
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
|
||||
uint32_t iters_per_sk_block_) const
|
||||
{
|
||||
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
|
||||
1);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_total_acc_buffers() const
|
||||
{
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
|
||||
uint32_t tiles_cover_little_blocks =
|
||||
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
|
||||
|
||||
uint32_t total_intersec_big =
|
||||
get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
|
||||
uint32_t total_intersec_little =
|
||||
get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
|
||||
|
||||
return sk_num_blocks + total_intersec_big + total_intersec_little;
|
||||
}
|
||||
|
||||
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
|
||||
{
|
||||
// TODO: from big to little
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
|
||||
if(tile_idx_ < tiles_cover_big_blocks)
|
||||
{
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
|
||||
k_iters_per_big_block;
|
||||
uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
|
||||
return touched_sk_blocks + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
|
||||
iters_per_little_sk_block;
|
||||
uint32_t current_intersec =
|
||||
get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
|
||||
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
|
||||
{
|
||||
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
if(block_idx_ < sk_num_big_blocks)
|
||||
{
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
|
||||
k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
|
||||
return block_idx_ + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(
|
||||
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
|
||||
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
struct GridwiseGemmPipeline_v3
|
||||
{
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t)
|
||||
{
|
||||
// TODO: improve applicability
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// global read 0
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// LDS write 0
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
num_loop--;
|
||||
|
||||
while(num_loop > 0)
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
block_sync_lds();
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
num_loop--;
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,213 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
// and sometimes useless instructions:
|
||||
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
|
||||
// instead
|
||||
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
|
||||
// tensor coordinate instead
|
||||
// 3. Don't use a pointer to VGPR buffer, use vector instead
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorDim,
|
||||
index_t ScalarPerVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
bool DstResetCoordinateAfterRun>
|
||||
struct ThreadwiseTensorSliceTransfer_v6r1r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
element_op_(element_op)
|
||||
{
|
||||
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(scalar_per_access)>>;
|
||||
|
||||
// loop over space-filling curve
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
// copy data from src_buf into src_vector_container
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
|
||||
|
||||
auto dst_vector_container = dst_vector_type{};
|
||||
|
||||
// apply pointwise operation
|
||||
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
|
||||
SrcData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
|
||||
|
||||
// apply type convert
|
||||
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// move coordinate
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetCoordinateResetStep()
|
||||
{
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(scalar_per_access)>>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
|
||||
? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx = DstResetCoordinateAfterRun
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -629,7 +629,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
@@ -682,6 +682,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<float, 8> tmp{src_thread_data};
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
|
||||
@@ -157,4 +157,76 @@ struct MagicDivision
|
||||
}
|
||||
};
|
||||
|
||||
struct MDiv
|
||||
{
|
||||
// 1 dword -> 3 dword storage
|
||||
uint32_t divisor;
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
__host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
|
||||
{
|
||||
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
|
||||
|
||||
multiplier = tmp[Number<0>{}];
|
||||
shift = tmp[Number<1>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}
|
||||
|
||||
__host__ __device__ void update(uint32_t divisor_)
|
||||
{
|
||||
divisor = divisor_;
|
||||
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
|
||||
|
||||
multiplier = tmp[Number<0>{}];
|
||||
shift = tmp[Number<1>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get() const { return divisor; }
|
||||
};
|
||||
|
||||
struct MDiv2
|
||||
{
|
||||
// 1 dword -> 2 dword storage, divisor need compute from runtime
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
__host__ __device__ MDiv2(uint32_t divisor_)
|
||||
{
|
||||
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
|
||||
|
||||
multiplier = tmp[Number<0>{}];
|
||||
shift = tmp[Number<1>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ MDiv2() : multiplier(0), shift(0) {}
|
||||
|
||||
__host__ __device__ uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -240,5 +240,21 @@ struct less
|
||||
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
|
||||
};
|
||||
|
||||
template <index_t X>
|
||||
__host__ __device__ constexpr auto next_power_of_two()
|
||||
{
|
||||
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
|
||||
constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
|
||||
return Y;
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
|
||||
{
|
||||
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
|
||||
constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
|
||||
return Number<Y>{};
|
||||
}
|
||||
|
||||
} // namespace math
|
||||
} // namespace ck
|
||||
|
||||
73
include/ck/utility/workgroup_barrier.hpp
Normal file
73
include/ck/utility/workgroup_barrier.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
#pragma once
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck {
|
||||
struct workgroup_barrier
|
||||
{
|
||||
__device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
|
||||
|
||||
__device__ uint32_t ld(uint32_t offset)
|
||||
{
|
||||
#if 0
|
||||
float d = llvm_amdgcn_raw_buffer_load_fp32(
|
||||
amdgcn_make_buffer_resource(base_ptr),
|
||||
0,
|
||||
offset,
|
||||
AMDGCN_BUFFER_GLC);
|
||||
union cvt {
|
||||
float f32;
|
||||
uint32_t u32;
|
||||
};
|
||||
cvt x;
|
||||
x.f32 = d;
|
||||
return x.u32;
|
||||
#endif
|
||||
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
|
||||
}
|
||||
|
||||
__device__ void wait_eq(uint32_t offset, uint32_t value)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) != value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void wait_lt(uint32_t offset, uint32_t value)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) < value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// enter critical zoon, assume buffer is zero when launch kernel
|
||||
__device__ void aquire(uint32_t offset) { wait_set(offset, 0, 1); }
|
||||
|
||||
// exit critical zoon, assume buffer is zero when launch kernel
|
||||
__device__ void release(uint32_t offset) { wait_set(offset, 1, 0); }
|
||||
|
||||
__device__ void inc(uint32_t offset)
|
||||
{
|
||||
__syncthreads();
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(base_ptr + offset, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t* base_ptr;
|
||||
};
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,121 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmStreamK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#if 0
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -20,8 +20,9 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
*/
|
||||
struct DeviceMem
|
||||
{
|
||||
DeviceMem() = delete;
|
||||
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
|
||||
DeviceMem(std::size_t mem_size);
|
||||
void Realloc(std::size_t mem_size);
|
||||
void* GetDeviceBuffer() const;
|
||||
std::size_t GetBufferSize() const;
|
||||
void ToDevice(const void* p) const;
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
add_instance_library(device_gemm_streamk_instance
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
// static constexpr auto GemmMNPadding =
|
||||
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,20 +10,57 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
|
||||
void DeviceMem::Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipFree(mpDeviceBuf));
|
||||
}
|
||||
mMemSize = mem_size;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
|
||||
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
|
||||
|
||||
std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
|
||||
|
||||
void DeviceMem::ToDevice(const void* p) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("ToDevice with an empty pointer");
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("FromDevice with an empty pointer");
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
|
||||
void DeviceMem::SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize));
|
||||
}
|
||||
}
|
||||
|
||||
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
|
||||
DeviceMem::~DeviceMem()
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipFree(mpDeviceBuf));
|
||||
}
|
||||
}
|
||||
|
||||
265
profiler/include/profiler/profile_gemm_streamk_impl.hpp
Normal file
265
profiler/include/profiler/profile_gemm_streamk_impl.hpp
Normal file
@@ -0,0 +1,265 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool profile_gemm_streamk_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC,
|
||||
uint32_t NumSKBlocks = 0xffffffff)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-3, 3});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances, "
|
||||
<< (do_verification ? "with verification" : "without verification") << std::endl;
|
||||
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
NumSKBlocks);
|
||||
DeviceMem workspace;
|
||||
std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
if(workspace_size != 0)
|
||||
{
|
||||
workspace.Realloc(workspace_size);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init C to zero before profiling next kernel
|
||||
c_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<CDataType, float>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f32";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, half_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, bhalf_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = bf16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = int8";
|
||||
}
|
||||
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = ColumnMajor";
|
||||
}
|
||||
|
||||
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = ColumnMajor";
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
|
||||
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_op_name << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -3,6 +3,7 @@ set(PROFILER_SOURCES
|
||||
profiler.cpp
|
||||
profile_gemm.cpp
|
||||
profile_gemm_splitk.cpp
|
||||
profile_gemm_streamk.cpp
|
||||
profile_gemm_bilinear.cpp
|
||||
profile_gemm_bias_add_reduce.cpp
|
||||
profile_gemm_add_add_fastgelu.cpp
|
||||
@@ -48,6 +49,7 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
|
||||
|
||||
155
profiler/src/profile_gemm_streamk.cpp
Normal file
155
profiler/src/profile_gemm_streamk.cpp
Normal file
@@ -0,0 +1,155 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/profile_gemm_streamk_impl.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
};
|
||||
|
||||
enum struct GemmDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_streamk"
|
||||
#define OP_DESC "StreamK GEMM"
|
||||
|
||||
int profile_gemm_streamk(int argc, char* argv[])
|
||||
{
|
||||
if(argc < 14)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
|
||||
printf("arg14: num_sk_blocks (optional)\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const int init_method = std::stoi(argv[5]);
|
||||
const bool do_log = std::stoi(argv[6]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
|
||||
const int M = std::stoi(argv[8]);
|
||||
const int N = std::stoi(argv[9]);
|
||||
const int K = std::stoi(argv[10]);
|
||||
|
||||
const int StrideA = std::stoi(argv[11]);
|
||||
const int StrideB = std::stoi(argv[12]);
|
||||
const int StrideC = std::stoi(argv[13]);
|
||||
const uint32_t NumSKBlocks =
|
||||
argc >= 15 ? static_cast<uint32_t>(std::stoul(std::string(argv[14]))) : 0xffffffff;
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
auto profile = [&](auto a_type,
|
||||
auto b_type,
|
||||
auto acc_type,
|
||||
auto c_type,
|
||||
auto a_layout,
|
||||
auto b_layout,
|
||||
auto c_layout) {
|
||||
using ADataType = decltype(a_type);
|
||||
using BDataType = decltype(b_type);
|
||||
using AccDataType = decltype(acc_type);
|
||||
using CDataType = decltype(c_type);
|
||||
|
||||
using ALayout = decltype(a_layout);
|
||||
using BLayout = decltype(b_layout);
|
||||
using CLayout = decltype(c_layout);
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
bool pass = ck::profiler::profile_gemm_streamk_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA <= 0) ? DefaultStrideA : StrideA,
|
||||
(StrideB <= 0) ? DefaultStrideB : StrideB,
|
||||
(StrideC <= 0) ? DefaultStrideC : StrideC,
|
||||
NumSKBlocks);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_streamk);
|
||||
406
test/block_swizzle_test/block_swizzle_test.cpp
Normal file
406
test/block_swizzle_test/block_swizzle_test.cpp
Normal file
@@ -0,0 +1,406 @@
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "simple_args.h"
|
||||
|
||||
simple_args_t create_arg(int argc, char** argv)
|
||||
{
|
||||
simple_args_t args;
|
||||
args.insert("m", "1024", "matrix m")
|
||||
.insert("n", "1024", "matrix n")
|
||||
.insert("k", "1024", "matrix k")
|
||||
.insert("m_per_block", "128", "m_per_block")
|
||||
.insert("n_per_block", "128", "n_per_block")
|
||||
.insert("k_per_block", "32", "k_per_block")
|
||||
.insert("num_cu", "104", "num cu")
|
||||
.insert("occupancy", "2", "occupancy")
|
||||
.parse(argc, argv);
|
||||
return args;
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
T integer_divide_ceil(T n, T d)
|
||||
{
|
||||
return (n + d - 1) / d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T min(T a, T b)
|
||||
{
|
||||
return a > b ? b : a;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T max(T a, T b)
|
||||
{
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
struct block_dispatcher_t
|
||||
{
|
||||
public:
|
||||
uint32_t m_per_block;
|
||||
uint32_t n_per_block;
|
||||
uint32_t k_per_block;
|
||||
uint32_t num_cu;
|
||||
uint32_t occupancy;
|
||||
uint32_t m;
|
||||
uint32_t n;
|
||||
uint32_t k;
|
||||
|
||||
//--------------------------------------
|
||||
|
||||
uint32_t sk_num_blocks;
|
||||
uint32_t sk_num_big_blocks;
|
||||
uint32_t sk_total_iters;
|
||||
|
||||
// uint32_t sk_num_blocks_per_tile; // how many
|
||||
|
||||
uint32_t dp_start_block_idx;
|
||||
uint32_t dp_iters_per_block;
|
||||
uint32_t dp_num_blocks;
|
||||
|
||||
uint32_t k_iters_per_tile;
|
||||
uint32_t k_iters_per_big_block;
|
||||
//--------------------------------------
|
||||
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 1;
|
||||
|
||||
void dump()
|
||||
{
|
||||
printf("%dx%dx%d(%dx%dx%d), cu:%d, occ:%d, grids:%d, sk_num_big_blocks:%d, "
|
||||
"sk_num_blocks:%d, sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, "
|
||||
"dp_num_blocks:%d, k_iters_per_tile:%d, k_iters_per_big_block:%d\n",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
m_per_block,
|
||||
n_per_block,
|
||||
k_per_block,
|
||||
num_cu,
|
||||
occupancy,
|
||||
get_grid_dims_x(),
|
||||
sk_num_big_blocks,
|
||||
sk_num_blocks,
|
||||
sk_total_iters,
|
||||
dp_start_block_idx,
|
||||
dp_iters_per_block,
|
||||
dp_num_blocks,
|
||||
k_iters_per_tile,
|
||||
k_iters_per_big_block);
|
||||
}
|
||||
|
||||
block_dispatcher_t(uint32_t m_per_block_,
|
||||
uint32_t n_per_block_,
|
||||
uint32_t k_per_block_,
|
||||
uint32_t num_cu_,
|
||||
uint32_t occupancy_,
|
||||
uint32_t m_,
|
||||
uint32_t n_,
|
||||
uint32_t k_)
|
||||
: m_per_block(m_per_block_),
|
||||
n_per_block(n_per_block_),
|
||||
k_per_block(k_per_block_),
|
||||
num_cu(num_cu_),
|
||||
occupancy(occupancy_),
|
||||
m(m_),
|
||||
n(n_),
|
||||
k(k_)
|
||||
{
|
||||
init();
|
||||
}
|
||||
|
||||
uint32_t get_grid_dims_x() { return dp_start_block_idx + dp_num_blocks; }
|
||||
|
||||
uint32_t get_block_idx(uint32_t bid)
|
||||
{
|
||||
// block id is linearily allocated along sk blocks (dp blocks are fine)
|
||||
// this function will compute blockIdx.x and the linear sk block mapping
|
||||
// uint32_t block_idx = 0;
|
||||
// if(bid < sk_num_big_blocks) {
|
||||
// uint32_t current_k_iter = bid * k_iters_per_big_block;
|
||||
// tile_idx = current_k_iter / k_iters_per_tile;
|
||||
// }
|
||||
return bid;
|
||||
}
|
||||
|
||||
uint32_t get_current_itr(uint32_t block_idx)
|
||||
{
|
||||
uint32_t current_itr = 0;
|
||||
if(block_idx < sk_num_big_blocks)
|
||||
{
|
||||
current_itr = block_idx * k_iters_per_big_block;
|
||||
}
|
||||
else if(block_idx < sk_num_blocks)
|
||||
{
|
||||
current_itr = (sk_num_big_blocks * k_iters_per_big_block) +
|
||||
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
}
|
||||
else if(block_idx >= dp_start_block_idx)
|
||||
{
|
||||
current_itr = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
|
||||
}
|
||||
return current_itr;
|
||||
}
|
||||
|
||||
void get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end)
|
||||
{
|
||||
if(block_idx < sk_num_big_blocks)
|
||||
{
|
||||
iter_start = block_idx * k_iters_per_big_block;
|
||||
iter_end = iter_start + k_iters_per_big_block;
|
||||
}
|
||||
else if(block_idx < sk_num_blocks)
|
||||
{
|
||||
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
|
||||
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
iter_end = iter_start + (k_iters_per_big_block - 1);
|
||||
}
|
||||
else if(block_idx >= dp_start_block_idx)
|
||||
{
|
||||
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
|
||||
iter_end = iter_start + dp_iters_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void init()
|
||||
{
|
||||
uint32_t num_tiles =
|
||||
impl::integer_divide_ceil(m, m_per_block) * impl::integer_divide_ceil(n, n_per_block);
|
||||
k_iters_per_tile = impl::integer_divide_ceil(k, k_per_block);
|
||||
|
||||
// one cu can hold one wg at one time, from the whole chip's point of view
|
||||
// if number of wg is same as num_cu, we call it 1 dispatch
|
||||
// if number of wg is 2x num_cu, we call it 2 dispatches.
|
||||
// one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial
|
||||
// dispatch)
|
||||
//
|
||||
uint32_t full_dispatches = num_tiles / num_cu;
|
||||
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
|
||||
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
|
||||
|
||||
uint32_t sk_occupancy = occupancy;
|
||||
uint32_t dp_tiles = full_dispatch_tiles;
|
||||
uint32_t sk_tiles = partial_dispatche_tiles;
|
||||
|
||||
if(full_dispatches < occupancy)
|
||||
{
|
||||
// in this case, we allocate all blocks as sk blocks
|
||||
// sk_occupancy = occupancy - full_dispatches;
|
||||
sk_occupancy = 1; // TODO: single occ seems better
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatche_tiles;
|
||||
}
|
||||
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
|
||||
{
|
||||
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
|
||||
// occupancy = 3, full_dispatches = 5, 8, 11 ...
|
||||
// occupancy = 4, full_dispatches = 7, 11 ...
|
||||
sk_occupancy = 1; // left 1 slot for sk occupancy
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatche_tiles;
|
||||
}
|
||||
else
|
||||
{
|
||||
// others, we reduce 1 dispatch from dp, together with partial dispatch,
|
||||
// to construct sk dispatch
|
||||
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
|
||||
dp_tiles = full_dispatch_tiles - num_cu;
|
||||
sk_tiles = partial_dispatche_tiles + num_cu;
|
||||
}
|
||||
|
||||
// dp_num_blocks = dp_tiles;
|
||||
// dp_start_block_idx = num_cu * sk_occupancy;
|
||||
dp_iters_per_block = k_iters_per_tile;
|
||||
|
||||
sk_total_iters = k_iters_per_tile * sk_tiles;
|
||||
|
||||
// printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
|
||||
// partial_dispatche_tiles:%d\n",
|
||||
// num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
|
||||
|
||||
{
|
||||
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
|
||||
uint32_t max_sk_tiles =
|
||||
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
|
||||
: impl::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
|
||||
|
||||
// if use dp for sk-block, how many iters do we need
|
||||
uint32_t dp_for_sk_iters = k_iters_per_tile;
|
||||
|
||||
uint32_t best_sk_score =
|
||||
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
|
||||
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
|
||||
tentative_sk_blocks++)
|
||||
{
|
||||
uint32_t tentative_sk_iters_per_block =
|
||||
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
|
||||
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
|
||||
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
|
||||
|
||||
// TODO: carefully adjust this parameter
|
||||
// the more sk_blocks_per_tile, the worse the overhead
|
||||
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
|
||||
if(tentative_sk_blocks % sk_tiles != 0)
|
||||
{
|
||||
// penalty for uneven divide
|
||||
cross_sk_blocks_overhead +=
|
||||
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
|
||||
}
|
||||
|
||||
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
|
||||
|
||||
if(tentative_sk_score < best_sk_score)
|
||||
{
|
||||
best_sk_score = tentative_sk_score;
|
||||
sk_num_blocks = tentative_sk_blocks;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_sk_score >= dp_for_sk_iters)
|
||||
{
|
||||
sk_num_blocks = 0;
|
||||
}
|
||||
|
||||
if(sk_num_blocks == 0)
|
||||
{
|
||||
sk_num_big_blocks = 0;
|
||||
k_iters_per_big_block = 0;
|
||||
|
||||
dp_num_blocks = num_tiles; // all tile to be dp block
|
||||
dp_start_block_idx = 0;
|
||||
sk_total_iters = 0; // clear this tiles
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct tile_work_t
|
||||
{
|
||||
uint32_t tile_idx;
|
||||
uint32_t iter_begin;
|
||||
uint32_t k_begin;
|
||||
uint32_t k_end;
|
||||
uint32_t k_iters_remaining;
|
||||
};
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
simple_args_t arg = create_arg(argc, argv);
|
||||
block_dispatcher_t block_dispatcher{arg.get_uint32("m_per_block"),
|
||||
arg.get_uint32("n_per_block"),
|
||||
arg.get_uint32("k_per_block"),
|
||||
arg.get_uint32("num_cu"),
|
||||
arg.get_uint32("occupancy"),
|
||||
arg.get_uint32("m"),
|
||||
arg.get_uint32("n"),
|
||||
arg.get_uint32("k")};
|
||||
block_dispatcher.dump();
|
||||
// simulate actual kernel launch
|
||||
uint32_t dim_x = block_dispatcher.get_grid_dims_x();
|
||||
uint32_t total_k_iters =
|
||||
impl::integer_divide_ceil(arg.get_uint32("k"), arg.get_uint32("k_per_block"));
|
||||
uint32_t num_tiles =
|
||||
impl::integer_divide_ceil(arg.get_uint32("m"), arg.get_uint32("m_per_block")) *
|
||||
impl::integer_divide_ceil(arg.get_uint32("n"), arg.get_uint32("n_per_block"));
|
||||
|
||||
std::vector<int> valid_tile_record(num_tiles * total_k_iters);
|
||||
|
||||
for(uint32_t bid = 0; bid < dim_x; bid++)
|
||||
{
|
||||
uint32_t block_idx = block_dispatcher.get_block_idx(bid);
|
||||
bool is_sk_block = block_idx < (block_dispatcher.sk_num_blocks);
|
||||
bool is_dp_block = block_idx >= block_dispatcher.dp_start_block_idx;
|
||||
uint32_t iter_start, iter_end;
|
||||
block_dispatcher.get_block_itr(block_idx, iter_start, iter_end);
|
||||
uint32_t total_iter_length = iter_end - iter_start;
|
||||
|
||||
while(true)
|
||||
{
|
||||
uint32_t iter_length_mod = iter_end % block_dispatcher.k_iters_per_tile;
|
||||
uint32_t current_iter_length =
|
||||
impl::min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
|
||||
total_iter_length);
|
||||
uint32_t tile_idx = (iter_end - 1) / block_dispatcher.k_iters_per_tile;
|
||||
uint32_t tile_iter_start =
|
||||
((iter_end - 1) % block_dispatcher.k_iters_per_tile) - current_iter_length + 1;
|
||||
|
||||
if(is_sk_block)
|
||||
{
|
||||
printf("[sk_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
|
||||
"iter_end:%d (len:%d)\n",
|
||||
bid,
|
||||
block_idx,
|
||||
tile_idx,
|
||||
iter_end - current_iter_length,
|
||||
tile_iter_start,
|
||||
iter_start,
|
||||
iter_end,
|
||||
current_iter_length);
|
||||
}
|
||||
else if(is_dp_block)
|
||||
{
|
||||
printf("[dp_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
|
||||
"iter_end:%d (len:%d)\n",
|
||||
bid,
|
||||
block_idx,
|
||||
tile_idx,
|
||||
iter_end - current_iter_length,
|
||||
tile_iter_start,
|
||||
iter_start,
|
||||
iter_end,
|
||||
current_iter_length);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("[other ] bid:%3d, block_idx:%3d\n", bid, block_idx);
|
||||
}
|
||||
|
||||
// some validation check
|
||||
for(auto i = iter_end - current_iter_length; i < iter_end; i++)
|
||||
{
|
||||
if(i >= valid_tile_record.size())
|
||||
{
|
||||
printf("unexpected, current iter:%d larger than max:%d\n",
|
||||
i,
|
||||
valid_tile_record.size());
|
||||
return -1;
|
||||
}
|
||||
valid_tile_record[i] = 1;
|
||||
}
|
||||
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int untouched = 0;
|
||||
for(auto i = 0; i < valid_tile_record.size(); i++)
|
||||
{
|
||||
if(valid_tile_record[i] != 1)
|
||||
{
|
||||
printf("untouched at %d (%d)\n", i, valid_tile_record.size());
|
||||
untouched++;
|
||||
}
|
||||
}
|
||||
printf("untouched %d/%d, %s\n",
|
||||
untouched,
|
||||
valid_tile_record.size(),
|
||||
untouched == 0 ? "valid" : "fail");
|
||||
}
|
||||
3
test/block_swizzle_test/rebuild.sh
Normal file
3
test/block_swizzle_test/rebuild.sh
Normal file
@@ -0,0 +1,3 @@
|
||||
CC=g++
|
||||
|
||||
$CC -Wall -std=c++17 -Iinclude -O3 block_swizzle_test.cpp -o block_swizzle_test.exe
|
||||
159
test/block_swizzle_test/simple_args.h
Normal file
159
test/block_swizzle_test/simple_args.h
Normal file
@@ -0,0 +1,159 @@
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <assert.h>
|
||||
|
||||
struct arg_content_t
|
||||
{
|
||||
std::string name; // key
|
||||
std::string value;
|
||||
std::string help_text;
|
||||
};
|
||||
|
||||
class simple_args_t
|
||||
{
|
||||
public:
|
||||
simple_args_t() {}
|
||||
simple_args_t& insert(const std::string& name_,
|
||||
const std::string& default_value_,
|
||||
const std::string& help_text_)
|
||||
{
|
||||
arg_content_t arg{name_, default_value_, help_text_};
|
||||
|
||||
if(arg_map.count(arg.name) != 0)
|
||||
{
|
||||
std::cout << "arg:" << arg.name << "already exist" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
arg_map[arg.name] = arg;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
void usage()
|
||||
{
|
||||
for(auto& content : arg_map)
|
||||
{
|
||||
std::vector<std::string> help_text_lines;
|
||||
size_t pos = 0;
|
||||
for(size_t next_pos = content.second.help_text.find('\n', pos);
|
||||
next_pos != std::string::npos;)
|
||||
{
|
||||
help_text_lines.push_back(
|
||||
std::string(content.second.help_text.begin() + pos,
|
||||
content.second.help_text.begin() + next_pos++));
|
||||
pos = next_pos;
|
||||
next_pos = content.second.help_text.find('\n', pos);
|
||||
}
|
||||
help_text_lines.push_back(std::string(content.second.help_text.begin() + pos,
|
||||
content.second.help_text.end()));
|
||||
|
||||
int arg_name_width = 16 - content.second.name.length();
|
||||
arg_name_width = arg_name_width > 0 ? arg_name_width : 2;
|
||||
std::cout << std::setw(4) << "-" << content.second.name << std::setw(arg_name_width)
|
||||
<< " " << help_text_lines[0] << std::endl;
|
||||
|
||||
for(auto help_next_line = std::next(help_text_lines.begin());
|
||||
help_next_line != help_text_lines.end();
|
||||
++help_next_line)
|
||||
{
|
||||
std::cout << std::setw(28) << " " << *help_next_line << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
bool parse(int argc, char* argv[], int start_index = 1)
|
||||
{
|
||||
if(argc <= start_index)
|
||||
{
|
||||
// std::cout << "not enough args (" << argc << ") with starting index " << start_index
|
||||
// << std::endl;
|
||||
return true;
|
||||
}
|
||||
for(int i = start_index; i < argc; i++)
|
||||
{
|
||||
std::string cur_arg = std::string(argv[i]);
|
||||
if(cur_arg[0] != '-')
|
||||
{
|
||||
std::cout << "illegal input" << std::endl;
|
||||
usage();
|
||||
return false;
|
||||
}
|
||||
else if(cur_arg[0] == '-' && cur_arg[1] == '?')
|
||||
{
|
||||
usage();
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t found_equal = cur_arg.find('=');
|
||||
if(found_equal == std::string::npos || found_equal == (cur_arg.length() - 1))
|
||||
{
|
||||
std::cout << "failed while parsing \"" << cur_arg << "\", "
|
||||
<< "arg must be in the form \"-name=value\"" << std::endl;
|
||||
return false;
|
||||
}
|
||||
std::string arg_name = cur_arg.substr(1, found_equal - 1);
|
||||
std::string arg_value = cur_arg.substr(found_equal + 1);
|
||||
if(arg_map.count(arg_name) == 0)
|
||||
{
|
||||
std::cout << "no such arg \"" << arg_name << "\" registered" << std::endl;
|
||||
return false;
|
||||
}
|
||||
arg_map[arg_name].value = arg_value;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string get(const std::string& name) const { return get_str(name); }
|
||||
|
||||
std::string get_str(const std::string& name) const
|
||||
{
|
||||
assert(arg_map.count(name) != 0);
|
||||
std::string value = arg_map.at(name).value;
|
||||
return value;
|
||||
}
|
||||
|
||||
int get_int(const std::string& name) const
|
||||
{
|
||||
assert(arg_map.count(name) != 0);
|
||||
int value = atoi(arg_map.at(name).value.c_str());
|
||||
return value;
|
||||
}
|
||||
|
||||
uint32_t get_uint32(const std::string& name) const
|
||||
{
|
||||
assert(arg_map.count(name) != 0);
|
||||
uint32_t value = strtoul(arg_map.at(name).value.c_str(), nullptr, 10);
|
||||
return value;
|
||||
}
|
||||
|
||||
uint64_t get_uint64(const std::string& name) const
|
||||
{
|
||||
assert(arg_map.count(name) != 0);
|
||||
uint64_t value = strtoull(arg_map.at(name).value.c_str(), nullptr, 10);
|
||||
return value;
|
||||
}
|
||||
|
||||
double get_double(const std::string& name) const
|
||||
{
|
||||
assert(arg_map.count(name) != 0);
|
||||
double value = atof(arg_map.at(name).value.c_str());
|
||||
return value;
|
||||
}
|
||||
|
||||
float get_float(const std::string& name) const
|
||||
{
|
||||
assert(arg_map.count(name) != 0);
|
||||
float value = atof(arg_map.at(name).value.c_str());
|
||||
return value;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, arg_content_t> arg_map;
|
||||
};
|
||||
Reference in New Issue
Block a user