update some codes

This commit is contained in:
joye
2025-06-03 08:52:00 +08:00
parent 945e3a44ad
commit d24bf107db
3 changed files with 7 additions and 4 deletions

View File

@@ -91,8 +91,8 @@ inline bool parse_cmd_args(int argc,
config.time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
conv_params =
ck::utils::conv::parse_conv_param(num_dim_spatial, num_conv_param_leading_args, argv);
}
else
{

View File

@@ -11,10 +11,10 @@ using CShuffleDataType = FP16;
using DsDataType = ck::Tuple<>;
using InDataType = FP16;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using OutLayout = ck::tensor_layout::convolution::NHWGK;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using InLayout = ck::tensor_layout::convolution::GNHWC;
using InLayout = ck::tensor_layout::convolution::NHWGC;
using OutElementOp = PassThrough;
using WeiElementOp = PassThrough;

View File

@@ -90,6 +90,9 @@ bool run_conv_bwd_data(const ExecutionConfig& config,
wei_element_op,
in_element_op);
DeviceMem workspace_buf(argument.GetWorkspaceSizeBytes());
conv.SetWorkspacePointer(&argument, workspace_buf.GetDeviceBuffer());
if(!conv.IsSupportedArgument(argument))
{
std::cerr << "wrong! device_conv with the specified compilation parameters does "