mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
add GetWorkSpaceSize to base arg (#253)
* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight
* remove redundant compute
* use datatype and split k to check whether a workspace is used
* remove unused computation for work space size
[ROCm/composable_kernel commit: 0d08cf1893]
This commit is contained in:
@@ -257,11 +257,11 @@ int main(int argc, char* argv[])
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
@@ -296,15 +296,53 @@ int main(int argc, char* argv[])
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
// alloc work space
|
||||
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
|
||||
float ave_time = 0.f;
|
||||
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
|
||||
wei_work_space_device_buf.SetZero();
|
||||
argument = conv->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
}
|
||||
|
||||
std::size_t flop = ck::utils::conv::get_flops(
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
|
||||
@@ -40,6 +40,8 @@ struct BaseOperator
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
|
||||
@@ -1175,6 +1175,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static size_t GetWorkSpaceSize(const Argument& arg)
|
||||
{
|
||||
size_t WorkSpaceSize = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
WorkSpaceSize =
|
||||
arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
|
||||
}
|
||||
}
|
||||
return WorkSpaceSize;
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static size_t GetWorkSpaceSize(const Argument& arg)
|
||||
{
|
||||
size_t WorkSpaceSize = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
|
||||
arg.filter_spatial_lengths_[1] * sizeof(float);
|
||||
}
|
||||
}
|
||||
return WorkSpaceSize;
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static size_t GetWorkSpaceSize(const Argument& arg)
|
||||
{
|
||||
size_t WorkSpaceSize = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
|
||||
arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] *
|
||||
sizeof(float);
|
||||
}
|
||||
}
|
||||
return WorkSpaceSize;
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final
|
||||
{
|
||||
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user