[CK TILE GEMM] Fix a merge conflict (#2753)

* Fixed a merge conflict in 245467f3
* Format the code
This commit is contained in:
Cong Ma
2025-08-27 12:08:09 -06:00
committed by GitHub
parent cfe5e448db
commit cd53e2e57e
3 changed files with 41 additions and 43 deletions

View File

@@ -40,8 +40,8 @@ template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);

View File

@@ -11,17 +11,17 @@ template <ck_tile::index_t NDimSpatial,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_warmup,
int n_repeat)
int n_warmup,
int n_repeat)
{
float ave_time = grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
@@ -124,11 +124,11 @@ int run_grouped_conv_bwd_data_example_with_layouts(
output_dev_buf.ToDevice(output.data());
ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
@@ -136,13 +136,13 @@ int run_grouped_conv_bwd_data_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
input_dev_buf.FromDevice(input.data());
bool pass = true;
@@ -152,17 +152,15 @@ int run_grouped_conv_bwd_data_example_with_layouts(
ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
input_host_ref.SetZero();
ck_tile::
reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input_host_ref,
weight,
output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK =
weight.get_element_size() / (conv_param.G_ * conv_param.K_);
ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input_host_ref,
weight,
output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
const auto rtol_atol =