mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[conv bwd-weight]Binding gemm k1 to conv n (#202)
* add some instance to develop * avoid bank conflicts for wrw for all instance * add small K1 test * delete some unused instance * binding gemm k1 to conv n * try using half_4 to do ds_read * reset buffer load oob and ds memcpy to default option * remove useless instances * remove redandunt space * remove printf code * clang-format-10 change * use fastest config * fix clang format for the other files * remove gemmk0 pad for output * add gemmk padding macro * add bank length computation * add template to distinguish the instance that need lds padding for wrw * use rocm5.1 as docker * use integer value for GEMM test * add Right padding macro * add 2 test asm code * using 256x256x32 tile size * 1. move dedicated transform into gridwisegemm's head file. 2. make lds tensor params a struct templete. 3. remove useless code * using small vec * 256*128 kernel size for example * remove asm files * use a new gridwise gemm header for bwd-weight * revert gridwise gemm v2r4r2 * change foramt * reset gridwise gemm v2r4r2 * remove unused code * revert instance file * revert example instance * format file * remove macros * resolve compile error * rename wrw kernel invoker * use gridwisegemm pipeline struct instead of implement run fucntion in the same header Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "gridwise_gemm_pipeline_v1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -235,8 +236,9 @@ template <index_t BlockSize,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool ABlockLdsExtraM1Wrw = false,
|
||||
bool BBlockLdsExtraN1Wrw = false>
|
||||
bool ABlockLdsExtraM1Wrw = false,
|
||||
bool BBlockLdsExtraN1Wrw = false,
|
||||
index_t NumGemmKPrefetchStage = 1>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -251,7 +253,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
|
||||
// M0/M1/M1Padding
|
||||
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
|
||||
@@ -511,6 +514,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K0 / K0PerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
|
||||
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
|
||||
@@ -548,9 +559,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k0_block_loop = K0 > K0PerBlock;
|
||||
// const bool has_main_k0_block_loop = K0 > K0PerBlock;
|
||||
const index_t num_loop = K0 / K0PerBlock;
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
|
||||
// return has_main_k0_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -771,51 +785,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
|
||||
// gridwise GEMM pipeline
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k0_block_data_begin = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
|
||||
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (K0 - K0PerBlock));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
b_b_k0_n_k1_block_desc,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user