mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-14 20:27:42 +00:00
[rocm-libraries] ROCm/rocm-libraries#5387 (commit 0c259bd)
[CK][CK Tile] Grouped Convolution Backward Weight set of fixes (#5387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Grouped Convolution Backward Weight split k fixes for CK tile kernels ## Technical Details - get k batch from kargs to get deduced k batch - multiply zeroing size by data type size - disable v6 (producing a incorrect results) ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result Pass ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
574c1c121a
commit
b8108662da
@@ -126,7 +126,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
|
||||
|
||||
@@ -180,7 +180,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
@@ -56,6 +57,7 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
using Types = ck_tile::builder::factory::internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
|
||||
std::end(kargs.wei_g_k_c_xs_lengths.data),
|
||||
1,
|
||||
@@ -64,10 +66,13 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
auto preprocess = [&]() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.wei_ptr, 0, zeroing_size, s_conf.stream_id_));
|
||||
hipMemsetAsync(kargs.wei_ptr,
|
||||
0,
|
||||
zeroing_size * sizeof(typename Types::EDataType),
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -156,7 +161,7 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
auto preprocess = [&]() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
|
||||
@@ -434,7 +434,9 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False:
|
||||
print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
|
||||
continue
|
||||
|
||||
if pipeline_version == "V6":
|
||||
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
|
||||
continue
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
|
||||
@@ -73,7 +73,7 @@ void run_cpu_validation(const ckt::Args<SIGNATURE>& args,
|
||||
|
||||
template <auto SIGNATURE>
|
||||
std::tuple<double, double>
|
||||
get_rtol_atol(const int num_accums, const int num_accums_split_k, const float max_accumulated_value)
|
||||
get_rtol_atol(const int num_accums, const int k_batch, const float max_accumulated_value)
|
||||
{
|
||||
using WeiDataType =
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
|
||||
@@ -84,6 +84,8 @@ get_rtol_atol(const int num_accums, const int num_accums_split_k, const float ma
|
||||
using ComputeType = WeiDataType;
|
||||
using AccDataType = float;
|
||||
|
||||
// Assign middle value of the range for auto deduce
|
||||
const int num_accums_split_k = k_batch > 0 ? k_batch : 64;
|
||||
auto rtol = ck_tile::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
auto atol = ck_tile::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
@@ -150,14 +152,17 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
auto run_alg = [&](auto&& run_alg_func) {
|
||||
for(auto& k_batch : split_k_values)
|
||||
{
|
||||
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
|
||||
ckt::Args<SIGNATURE> args_k_batch = args;
|
||||
args_k_batch.k_batch = k_batch;
|
||||
std::tie(is_supported, avg_time, op_name) =
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args,
|
||||
args_k_batch,
|
||||
[&](std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
@@ -182,7 +187,7 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
// Check with cpu verification to get a values
|
||||
run_cpu_validation<SIGNATURE>(args, outputs, reference.get());
|
||||
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
|
||||
}
|
||||
all_instances_valid = false;
|
||||
}
|
||||
|
||||
@@ -136,7 +136,12 @@ int call_profiler(const ckt::Args<SIGNATURE>& args, const std::string& split_k,
|
||||
split_k,
|
||||
inputs.get(),
|
||||
outputs.get(),
|
||||
ck_tile::stream_config{nullptr, time_kernel});
|
||||
ck_tile::stream_config{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
true /*is_gpu_timer_*/});
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "\nBest configuration parameters:" << "\n\tname: " << op_name
|
||||
|
||||
Reference in New Issue
Block a user