[CK][CK Tile] Add grouped conv backward weight tile test and fix tr load in BASE_V1 pipeline (#5115)

## Motivation

Test grouped conv backward weight from ck tile and fix incorrect values.

## Technical Details

- Add test for CI
- Add daily tests
- Fix transpose load in BASE_V1 pipeline

## Test Plan

test_grouped_convnd_backward_weight_tile

## Test Result

in progress

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-783
This commit is contained in:
Bartłomiej Kocot
2026-03-10 04:03:04 +01:00
committed by GitHub
parent 751e29ccb6
commit c749ef9da3
5 changed files with 312 additions and 45 deletions

3
Jenkinsfile vendored
View File

@@ -1524,7 +1524,8 @@ pipeline {
setup_args = "NO_CK_BUILD"
execute_args = """ python3 ../experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler && \
cmake .. --preset dev-gfx90a -D CK_EXPERIMENTAL_BUILDER=ON && \
make -j64 test_grouped_convnd_fwd_tile && \
make -j64 test_grouped_convnd_fwd_tile test_grouped_convnd_bwd_weight_tile && \
./bin/test_grouped_convnd_bwd_weight_tile && \
./bin/test_grouped_convnd_fwd_tile"""
}
steps{

View File

@@ -288,6 +288,8 @@ def parse_bwd_weight_instances(instances, problem_name):
end = instance.rindex('>')
params_str = instance[start:end]
args = parse_instance_string(params_str)
direct_load = False
is_v3_instance = instance.find("Xdl_CShuffleV3") != -1
is_two_stage_instance = instance.find("TwoStage") != -1
@@ -357,6 +359,7 @@ def parse_bwd_weight_instances(instances, problem_name):
if len(args) != 45:
raise RuntimeError(f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}")
direct_load = int(args[43]) == 1
num_groups_to_merge = int(args[44])
# Block GEMM pipeline parameters
@@ -396,6 +399,16 @@ def parse_bwd_weight_instances(instances, problem_name):
if pipeline_version == "V5":
pipeline_version = "V6"
if direct_load:
if pipeline_version == "V1":
pipeline_version = "ASYNC_V1"
elif pipeline_version == "V4":
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError(
f"Not supported pipeline for direct load: pipeline_version={pipeline_version} in instance: {instance}"
)
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
warp_size = 64

View File

@@ -112,6 +112,9 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t kLdsAlignmentInBytes = 16;
@@ -272,29 +275,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
// LDS write 0
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
}
@@ -310,7 +311,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
@@ -322,29 +324,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// LDS write i + 1
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
// LDS write i + 1
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
@@ -354,7 +354,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
// tail
{
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
@@ -479,29 +480,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
// LDS write 0
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
}
@@ -518,36 +517,38 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
// move to i + 2
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
// LDS write i + 1
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
@@ -558,7 +559,11 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
}
return c_block_tile;

View File

@@ -18,6 +18,17 @@ elseif(DL_KERNELS)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9")
if(CK_EXPERIMENTAL_BUILDER)
add_gtest_executable(test_grouped_convnd_bwd_weight_tile test_grouped_convnd_bwd_weight_tile.cpp)
target_compile_options(test_grouped_convnd_bwd_weight_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat)
target_link_libraries(test_grouped_convnd_bwd_weight_tile PRIVATE gtest_main getopt::getopt utility)
if(TARGET device_grouped_conv_bwd_weight_tile_instances)
target_link_libraries(test_grouped_convnd_bwd_weight_tile PRIVATE device_grouped_conv_bwd_weight_tile_instances)
endif()
endif()
endif()
add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility)

View File

@@ -0,0 +1,237 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "profiler/grouped_convolution_backward_weight_tile_algs.hpp"
static ck::index_t args_mask = 0xffff;
static ck::index_t instance_index = -1;
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace ckp = ck_tile::builder::profiling;
template <ck_tile::index_t num_spatial_dim_,
ckb::DataType data_type_,
ckb::DataType acc_data_type_,
ckb::TensorLayout in_layout_,
ckb::TensorLayout wei_layout_,
ckb::TensorLayout out_layout_>
struct SignatureDetails
{
static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_;
static constexpr ckb::DataType data_type = data_type_;
static constexpr ckb::DataType acc_data_type = acc_data_type_;
static constexpr ckb::TensorLayout in_layout = in_layout_;
static constexpr ckb::TensorLayout wei_layout = wei_layout_;
static constexpr ckb::TensorLayout out_layout = out_layout_;
};
template <typename SignatureDetailsType>
class TestGroupedConvndBwdWeightTile : public ::testing::Test
{
protected:
static constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = SignatureDetailsType::data_type,
.accumulation_data_type = SignatureDetailsType::acc_data_type,
.input = {.config = {.layout = SignatureDetailsType::in_layout}},
.weight = {.config = {.layout = SignatureDetailsType::wei_layout}},
.output = {.config = {.layout = SignatureDetailsType::out_layout}}};
std::vector<ckt::Args<SIGNATURE>> conv_args;
std::vector<std::string> split_ks{"-1", "1", "2"};
template <ck::index_t NDimSpatial>
void Run()
{
ASSERT_FALSE(conv_args.empty());
bool pass = true;
for(size_t i = 0; i < conv_args.size(); i++)
{
for(auto& split_k : split_ks)
{
if((args_mask & (1 << i)) == 0)
{
continue;
}
auto& args = conv_args[i];
auto inputs = alloc_inputs(args);
auto outputs = alloc_outputs(args);
ckt::init_tensor_buffer_uniform_int(
inputs.get().input, args.make_input_descriptor(), -5, 5);
ckt::init_tensor_buffer_uniform_int(
inputs.get().output, args.make_output_descriptor(), -5, 5);
std::cout << args.make_input_descriptor() << std::endl;
std::cout << args.make_weight_descriptor() << std::endl;
std::cout << args.make_output_descriptor() << std::endl;
[[maybe_unused]] auto&& [case_passed, avg_time, op_name, best_split_k] =
ckp::run_grouped_conv_backward_weight_tile_algs(
args,
split_k,
inputs.get(),
outputs.get(),
ck_tile::stream_config{nullptr, false /*time_kernel*/});
pass = pass && case_passed;
}
}
EXPECT_TRUE(pass);
}
void conv_args_append(std::size_t,
std::size_t G,
std::size_t N,
std::size_t K,
std::size_t C,
const std::vector<std::size_t>& filter_spatial_lengths,
const std::vector<std::size_t>& input_spatial_lengths,
const std::vector<std::size_t>& conv_filter_strides,
const std::vector<std::size_t>& conv_filter_dilations,
const std::vector<std::size_t>& input_left_pads,
const std::vector<std::size_t>& input_right_pads)
{
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = N,
.groups = G,
.input_channels = C,
.output_channels = K,
.image = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
input_spatial_lengths),
.filter = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
filter_spatial_lengths),
},
.filter_strides = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
conv_filter_strides),
.filter_dilation =
ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
conv_filter_dilations),
.input_left_pad = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
input_left_pads),
.input_right_pad =
ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
input_right_pads),
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
conv_args.push_back(args);
}
};
using KernelTypes2d = ::testing::Types<SignatureDetails<2,
ckb::DataType::FP32,
ckb::DataType::FP32,
ckb::TensorLayout::NHWGC,
ckb::TensorLayout::GKYXC,
ckb::TensorLayout::NHWGK>,
SignatureDetails<2,
ckb::DataType::FP16,
ckb::DataType::FP32,
ckb::TensorLayout::NHWGC,
ckb::TensorLayout::GKYXC,
ckb::TensorLayout::NHWGK>,
SignatureDetails<2,
ckb::DataType::BF16,
ckb::DataType::FP32,
ckb::TensorLayout::NHWGC,
ckb::TensorLayout::GKYXC,
ckb::TensorLayout::NHWGK>>;
using KernelTypes3d = ::testing::Types<SignatureDetails<3,
ckb::DataType::FP32,
ckb::DataType::FP32,
ckb::TensorLayout::NDHWGC,
ckb::TensorLayout::GKZYXC,
ckb::TensorLayout::NDHWGK>,
SignatureDetails<3,
ckb::DataType::FP16,
ckb::DataType::FP32,
ckb::TensorLayout::NDHWGC,
ckb::TensorLayout::GKZYXC,
ckb::TensorLayout::NDHWGK>,
SignatureDetails<3,
ckb::DataType::BF16,
ckb::DataType::FP32,
ckb::TensorLayout::NDHWGC,
ckb::TensorLayout::GKZYXC,
ckb::TensorLayout::NDHWGK>>;
template <typename SignatureDetailsType>
class TestGroupedConvndBwdWeightTile2d : public TestGroupedConvndBwdWeightTile<SignatureDetailsType>
{
};
template <typename SignatureDetailsType>
class TestGroupedConvndBwdWeightTile3d : public TestGroupedConvndBwdWeightTile<SignatureDetailsType>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightTile2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightTile3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdWeightTile2d, Test2D)
{
this->conv_args.clear();
this->conv_args_append(2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
this->conv_args_append(2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0});
this->conv_args_append(2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
this->conv_args_append(2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
this->conv_args_append(2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
this->conv_args_append(2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
this->conv_args_append(2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
this->conv_args_append(2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
this->conv_args_append(2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
this->conv_args_append(2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndBwdWeightTile3d, Test3D)
{
this->conv_args.clear();
this->conv_args_append(
3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
this->conv_args_append(
3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
this->conv_args_append(
3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
this->conv_args_append(
3, 1, 1, 1, 32, {3, 3, 3}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
this->conv_args_append(
3, 1, 1, 64, 3, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
this->conv_args_append(
3, 1, 1, 1, 1, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
this->conv_args_append(
3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
this->template Run<3>();
}
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}
else if(argc == 3)
{
args_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}