mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Fix performance issue when passing tensor descriptor from host to kernel by void pointers (#27)
* use address_space(4) in kernel signature to fix performance issue when passing tensor descriptor from host to kernel by (void) pointers * remove passing by pointer* option (only use pass by value or void*)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -363,171 +363,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER
|
||||
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
|
||||
using BDesc = decltype(in_gemmk_gemmn_global_desc);
|
||||
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
|
||||
|
||||
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
|
||||
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
|
||||
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
|
||||
|
||||
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
|
||||
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
|
||||
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
@@ -561,111 +396,115 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
true,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
true,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
false,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
false,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1017,171 +856,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER
|
||||
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
|
||||
using BDesc = decltype(in_gemmk_gemmn_global_desc);
|
||||
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
|
||||
|
||||
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
|
||||
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
|
||||
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
|
||||
|
||||
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
|
||||
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
|
||||
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_gemmn_global_desc)*,
|
||||
const FloatAB*,
|
||||
decltype(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
reinterpret_cast<const ADesc*>(
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_wei_global,
|
||||
reinterpret_cast<const BDesc*>(
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
|
||||
p_in_global,
|
||||
reinterpret_cast<const CDesc*>(
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer()),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
@@ -1215,114 +889,117 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
true,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
true,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
false,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
const FloatAB*,
|
||||
const void*,
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
|
||||
ADesc,
|
||||
FloatAB,
|
||||
BDesc,
|
||||
FloatAB,
|
||||
CDesc,
|
||||
FloatC,
|
||||
false,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
launch_kernel(
|
||||
kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
(void __CONSTANT__*)
|
||||
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_wei_global,
|
||||
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
|
||||
p_in_global,
|
||||
(void __CONSTANT__*)
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
|
||||
.GetDeviceBuffer(),
|
||||
p_out_global);
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
@@ -11,6 +11,47 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
// pass tensor descriptor by __CONSTANT__ void pointer
|
||||
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
|
||||
// non-modifiable parameter address space, so compiler can enable corresponding optimization
|
||||
template <typename GridwiseGemm,
|
||||
typename AGlobalDesc,
|
||||
typename FloatA,
|
||||
typename BGlobalDesc,
|
||||
typename FloatB,
|
||||
typename CGlobalDesc,
|
||||
typename FloatC,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
|
||||
const FloatA* __restrict__ p_a_global,
|
||||
const void __CONSTANT__* p_b_k_n_global_desc,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
|
||||
FloatC* __restrict__ p_c_global)
|
||||
{
|
||||
// first cast void __CONSTANT__* to void*
|
||||
// second cast void* to Desc*
|
||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||
const auto a_k_m_global_desc =
|
||||
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
|
||||
const auto b_k_n_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
|
||||
const auto c_m0_m1_n0_n1_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
|
||||
|
||||
GridwiseGemm{}.Run(a_k_m_global_desc,
|
||||
p_a_global,
|
||||
b_k_n_global_desc,
|
||||
p_b_global,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
@@ -427,7 +468,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
}
|
||||
}
|
||||
|
||||
// pass tensor descriptor by reference
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
@@ -452,57 +492,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by pointers
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc* p_b_k_n_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_k_m_global_desc = *p_a_k_m_global_desc;
|
||||
const auto b_k_n_global_desc = *p_b_k_n_global_desc;
|
||||
const auto c_m0_m1_n0_n1_global_desc = *p_c_m0_m1_n0_n1_global_desc;
|
||||
|
||||
Run(a_k_m_global_desc,
|
||||
p_a_global,
|
||||
b_k_n_global_desc,
|
||||
p_b_global,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by void*
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const void* p_a_k_m_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const void* p_b_k_n_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const void* p_c_m0_m1_n0_n1_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_k_m_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_k_m_global_desc);
|
||||
const auto b_k_n_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_k_n_global_desc);
|
||||
const auto c_m0_m1_n0_n1_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc);
|
||||
|
||||
Run(a_k_m_global_desc,
|
||||
p_a_global,
|
||||
b_k_n_global_desc,
|
||||
p_b_global,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#endif
|
||||
#include "bfloat16_dev.hpp"
|
||||
|
||||
// address space for kernel parameter
|
||||
#define __CONSTANT__ __attribute__((address_space(4)))
|
||||
|
||||
// device backend
|
||||
#define CK_DEVICE_BACKEND_AMD 1
|
||||
|
||||
@@ -108,9 +111,8 @@
|
||||
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0
|
||||
#endif
|
||||
|
||||
// pass tensor descriptor by value, pointer or void*
|
||||
// pass tensor descriptor by value or void*
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER 0
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
|
||||
|
||||
Reference in New Issue
Block a user