mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
[Perf][Bwd-weights]Lds re-layout to avoid ds read/write bank conflict and balance ds ops with address calculations (#190)
* add some instance to develop
* avoid bank conflicts for wrw for all instance
* add small K1 test
* delete some unused instance
* reset buffer load oob and ds memcpy to default option
* remove useless instances
* remove redandunt space
* remove printf code
* clang-format-10 change
* fix clang format for the other files
* add bank length computation
* add template to distinguish the instance that need lds padding for wrw
* use rocm5.1 as docker
* use integer value for GEMM test
* 1. move dedicated transform into gridwisegemm's head file. 2. make lds tensor params a struct templete. 3. remove useless code
* use a new gridwise gemm header for bwd-weight
* revert gridwise gemm v2r4r2
* change foramt
* rename kernel invoker
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: b9b9c3b814]
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
|
||||
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -81,6 +81,20 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
// Bytes per 32 lds bank: 32 * 4 bytes
|
||||
static constexpr auto BankLength = 128;
|
||||
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
|
||||
|
||||
// M1 & M0
|
||||
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
|
||||
static constexpr auto ABlockLdsM1Padding = 4;
|
||||
|
||||
// N1 & N0
|
||||
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
|
||||
static constexpr auto BBlockLdsN1Padding = 4;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
@@ -205,7 +219,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
@@ -233,6 +247,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -241,12 +258,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
@@ -274,6 +296,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -282,10 +307,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||
@@ -465,7 +495,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -482,7 +512,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -502,7 +532,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -519,7 +549,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user