Xdlops refactor fix (#22)

* added constexpr ahead of adptor; clean unused driver; rename M/NPerWave to M/NPerXDL

* fixed bwd

* fixed comment
This commit is contained in:
zjing14
2021-08-23 11:22:10 -05:00
committed by GitHub
parent c6f26bb480
commit 9d3f634a3c
11 changed files with 111 additions and 383 deletions

View File

@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
@@ -40,7 +41,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
const index_t thread_id = get_thread_local_1d_id();
const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));

View File

@@ -103,8 +103,8 @@ template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
@@ -183,14 +183,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
}
__host__ __device__ static constexpr index_t
@@ -220,8 +222,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatAB,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerWave,
NPerWave,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>;
@@ -366,8 +368,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatAB,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerWave,
NPerWave,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};

View File

@@ -749,8 +749,9 @@ struct XdlopsGemm
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
const auto laneId = GetLaneId();
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),