Add host API (#220)

* Add host API

* manually rebase on develop

* clean

* manually rebase on develop

* exclude tests from all target

* address review comments

* update client app name

* fix missing lib name

* clang-format update

* refactor

* refactor

* refactor

* refactor

* refactor

* fix test issue

* refactor

* refactor

* refactor

* upate cmake and readme

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
JD
2022-05-12 09:21:01 -05:00
committed by GitHub
parent 0f912e205e
commit cec69bc3bc
131 changed files with 1666 additions and 1089 deletions

View File

@@ -1,8 +1,9 @@
#ifndef DEVICE_BASE_HPP
#define DEVICE_BASE_HPP
#pragma once
#include <string>
#include "stream_config.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
@@ -22,7 +23,10 @@ struct BaseInvoker
BaseInvoker(const BaseInvoker&) = default;
BaseInvoker& operator=(const BaseInvoker&) = default;
virtual float Run(const BaseArgument*, int = 1) = 0;
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
{
return float{0};
}
virtual ~BaseInvoker() {}
};
@@ -33,8 +37,8 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default;
virtual bool IsSupportedArgument(const BaseArgument*) = 0;
virtual std::string GetTypeString() const = 0;
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
virtual ~BaseOperator() {}
};
@@ -42,4 +46,3 @@ struct BaseOperator
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -693,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int /* nrepeat */ = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -729,6 +729,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
@@ -748,26 +749,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>,
true>;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
else
{
@@ -788,35 +791,38 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>,
false>;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
return 0;
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -428,7 +428,7 @@ struct DeviceBatchedGemmXdl
{
using Argument = DeviceBatchedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
@@ -437,49 +438,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
@@ -602,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
true>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -635,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
false>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -655,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -642,7 +642,7 @@ struct
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -727,8 +727,8 @@ struct
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -771,8 +771,8 @@ struct
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -795,9 +795,10 @@ struct
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -605,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -684,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -723,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -745,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -663,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -697,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -717,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -498,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -529,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -549,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
WeiElementwiseOperation,
OutElementwiseOperation>;
float ave_time = launch_and_time_kernel(naive_conv3d_fwd,
nrepeat,
float ave_time = launch_and_time_kernel(stream_config,
naive_conv3d_fwd,
dim3(256),
dim3(256),
0,
@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
OutElementwiseOperation,
remove_reference_t<Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
remove_reference_t<Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -1241,7 +1241,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
@@ -1316,8 +1316,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
true>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -1349,8 +1349,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
false>;
ave_time += launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -1369,9 +1369,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -747,7 +747,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -795,8 +795,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -826,8 +826,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -846,9 +846,10 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -503,7 +503,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int /* nrepeat */ = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -536,6 +536,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
@@ -554,24 +555,26 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
else
{
@@ -591,33 +594,36 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
return 0;
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -290,7 +290,7 @@ struct DeviceGemmXdl
{
using Argument = DeviceGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -339,8 +339,8 @@ struct DeviceGemmXdl
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -370,8 +370,8 @@ struct DeviceGemmXdl
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -391,9 +391,10 @@ struct DeviceGemmXdl
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -264,7 +264,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
{
using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -320,8 +320,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -359,8 +359,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -382,9 +382,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -273,7 +273,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -329,8 +329,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -368,8 +368,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -391,9 +391,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -312,7 +312,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -374,8 +374,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -418,8 +418,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -443,9 +443,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -440,7 +440,7 @@ struct DeviceGemm_Xdl_CShuffle
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
@@ -487,42 +487,22 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
@@ -538,52 +518,32 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -385,8 +385,11 @@ struct DeviceGemmXdlSplitK
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
@@ -408,50 +411,30 @@ struct DeviceGemmXdlSplitK
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ShowInfo(arg);
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
// FIXME: this should be moved outside of DeviceOp
hipGetErrorString(
hipMemset(arg.p_c_grid_,
0,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
sizeof(CDataType)));
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(
hipMemset(arg.p_c_grid_,
0,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
@@ -531,9 +514,10 @@ struct DeviceGemmXdlSplitK
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -391,8 +391,11 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
@@ -414,51 +417,29 @@ struct DeviceGemmXdlSplitKCShuffle
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ShowInfo(arg);
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
@@ -542,9 +523,10 @@ struct DeviceGemmXdlSplitKCShuffle
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -449,7 +449,7 @@ struct DeviceGroupedGemmXdl
{
using Argument = DeviceGroupedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
@@ -510,8 +510,8 @@ struct DeviceGroupedGemmXdl
true,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
@@ -534,8 +534,8 @@ struct DeviceGroupedGemmXdl
false,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
@@ -550,9 +550,10 @@ struct DeviceGroupedGemmXdl
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -204,7 +204,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
@@ -241,8 +241,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
return launch_and_time_kernel(kernel,
nrepeat,
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -257,9 +257,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

View File

@@ -211,7 +211,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
@@ -253,8 +253,8 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
@@ -272,9 +272,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};

View File

@@ -182,7 +182,7 @@ struct DeviceReduceBlockWiseSecondCall
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);
@@ -224,8 +224,8 @@ struct DeviceReduceBlockWiseSecondCall
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
@@ -243,10 +243,11 @@ struct DeviceReduceBlockWiseSecondCall
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -245,7 +245,7 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
@@ -275,8 +275,6 @@ struct DeviceReduceMultiBlockAtomicAdd
float avg_time = 0;
KernelTimer timer;
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
InDataType,
@@ -287,50 +285,38 @@ struct DeviceReduceMultiBlockAtomicAdd
InElementwiseOperation,
AccElementwiseOperation>;
printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n",
arg.gridSize,
BlockSize);
printf("Warm up\n");
avg_time += launch_and_time_kernel(stream_config,
kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
out_grid_desc_m,
arg.out_dev_,
static_cast<OutDataType>(0.0f));
for(int i = 0; i < nrepeat + 1; i++)
{
if(i == 1)
timer.Start();
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.alpha_,
arg.in_dev_,
arg.out_dev_);
launch_kernel(kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
out_grid_desc_m,
arg.out_dev_,
static_cast<OutDataType>(0.0f));
return avg_time;
}
launch_kernel(kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.alpha_,
arg.in_dev_,
arg.out_dev_);
};
timer.End();
avg_time = timer.GetElapsedTime() / nrepeat;
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -273,7 +273,7 @@ struct DeviceReduceMultiBlockPartialReduce
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
@@ -313,8 +313,8 @@ struct DeviceReduceMultiBlockPartialReduce
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
@@ -331,10 +331,11 @@ struct DeviceReduceMultiBlockPartialReduce
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -212,7 +212,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
@@ -254,8 +254,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InElementwiseOperation,
OutElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
@@ -272,10 +272,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override