mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Optimize grouped conv bwd weight for small M and N (#1303)
* Optimize grouped conv bwd weight for small M and N * Fixes
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -603,8 +603,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
|
||||
Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -669,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(Number<mpair>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
@@ -740,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
|
||||
Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -803,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(Number<npair>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -781,8 +781,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
|
||||
Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -847,7 +847,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(Number<mpair>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
@@ -918,8 +918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
|
||||
Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -981,7 +981,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(Number<npair>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
|
||||
Reference in New Issue
Block a user