[Bug Fix] GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 loop issue (#44)
* change the method of computing KPad
* remove unused variable: BatchLen
* change KPerBlock to K0PerBlock
* fix bug for K0 == K0PerBlock
* fix bug in getting the K0 index
* use math::integer_divide_ceil
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 6014185ac6]
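Reviewer note: the heart of this fix is the host-side K-padding arithmetic in the backward-weight drivers further down. A minimal standalone sketch of the before/after computation, with illustrative function names rather than the real driver API, and assuming ck's math::integer_divide_ceil is the usual (x + y - 1) / y integer ceiling division:

#include <cmath>

// assumed definition of the integer ceiling division used by the new code
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

// old method: GemmKTotal / GemmK1 truncates before the std::ceil, so the
// resulting pad GemmKBatch * GemmK0 * GemmK1 can fall short of GemmKTotal
int gemm_k0_old(int GemmKTotal, int GemmK1, int GemmKPerBlock, int GemmKBatch)
{
    const int GemmK    = GemmKTotal / GemmK1; // truncating division
    const int BatchLen = static_cast<int>(std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)));
    return BatchLen * GemmKPerBlock;
}

// new method: one integer ceiling division over the full GemmKTotal,
// guaranteeing GemmKBatch * GemmK0 * GemmK1 >= GemmKTotal
int gemm_k0_new(int GemmKTotal, int GemmK1, int GemmKPerBlock, int GemmKBatch)
{
    return integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
}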
@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -37,14 +38,14 @@ __global__ void

     __shared__ FloatAB p_shared_block[shared_block_size];

-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      p_shared_block,
-                      a_b_k0_m_k1_grid_desc,
-                      b_b_k0_n_k1_grid_desc,
-                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                      c_block_cluster_adaptor);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_b_k0_m_k1_grid_desc,
+                                                  b_b_k0_n_k1_grid_desc,
+                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  c_block_cluster_adaptor);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 template <typename GridwiseGemm,
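Reviewer note: the call site changes from GridwiseGemm::Run(...) to GridwiseGemm::template Run<HasMainKBlockLoop>(...). The extra template keyword is required by C++ once Run becomes a member function template invoked through a dependent type; a minimal sketch of the rule (types here are illustrative):

struct Gemm
{
    template <bool MainLoop>
    static void Run() {}
};

template <typename GG>
void caller()
{
    GG::template Run<true>(); // OK: tells the parser that Run names a template
    // GG::Run<true>();       // ill-formed: '<' would parse as less-than
}

int main() { caller<Gemm>(); }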
@@ -53,7 +54,8 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -81,14 +83,14 @@ __global__ void

     __shared__ FloatAB p_shared_block[shared_block_size];

-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      p_shared_block,
-                      a_b_k0_m_k1_grid_desc,
-                      b_b_k0_n_k1_grid_desc,
-                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                      c_block_cluster_adaptor);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_b_k0_m_k1_grid_desc,
+                                                  b_b_k0_n_k1_grid_desc,
+                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  c_block_cluster_adaptor);
 }
 #endif

@@ -102,7 +104,7 @@ template <index_t BlockSize,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t KPerBlock,
+          index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
@@ -158,13 +160,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();

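Reviewer note: the rename does not change the layout logic. With ABlockLdsExtraM set, the explicit strides {(MPerBlock + 1) * K1, K1, 1} pad each K0 slice by one M element, a common trick for steering clear of LDS bank conflicts. A small sketch of the resulting offset arithmetic, with illustrative names rather than the real descriptor API:

#include <cstdio>

// offset into a [K0PerBlock, MPerBlock, K1] block whose strides are
// {(MPerBlock + 1) * K1, K1, 1}: every k0 slice carries K1 padding elements
constexpr int lds_offset(int k0, int m, int k1, int MPerBlock, int K1)
{
    return k0 * (MPerBlock + 1) * K1 + m * K1 + k1;
}

int main()
{
    constexpr int MPerBlock = 128, K1 = 4;
    // unpadded k0 stride would be 128 * 4 = 512 (a power of two, conflict-prone);
    // the padded stride is 129 * 4 = 516
    std::printf("%d\n", lds_offset(1, 0, 0, MPerBlock, K1)); // prints 516
}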
@@ -173,13 +175,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();

@@ -220,7 +222,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
              KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
             return false;

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

         // check M01, N01
@@ -248,6 +250,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         return grid_size;
     }

+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    {
+        const bool has_main_k0_block_loop = K0 > K0PerBlock;
+
+        return has_main_k0_block_loop;
+    }
+
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
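Reviewer note: this new predicate carries the fix for the K0 == K0PerBlock case — the main loop should only run when more than one K0 tile exists. Worked through on small numbers:

    K0 = 16, K0PerBlock = 4  ->  has_main_k0_block_loop = true  (main loop covers three tiles, tail covers the last)
    K0 = 4,  K0PerBlock = 4  ->  has_main_k0_block_loop = false (tail only)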
@@ -258,13 +267,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();

@@ -273,13 +282,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
        {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();

@@ -338,6 +347,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));

+    template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -376,13 +386,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();

@@ -390,8 +400,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
-                make_tuple(Number<KPerBlock>{} * Number<MPerBlock + 1>{} * K1,
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                            Number<MPerBlock + 1>{} * K1,
                            K1,
                            I1));
@@ -399,7 +409,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 max_lds_align);
         }
     }();
@@ -408,13 +418,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();

@@ -422,8 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
-                make_tuple(Number<KPerBlock>{} * Number<NPerBlock + 1>{} * K1,
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                            Number<NPerBlock + 1>{} * K1,
                            K1,
                            I1));
@@ -431,7 +441,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 max_lds_align);
         }
     }();
@@ -439,7 +449,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, KPerBlock, MPerBlock, K1>,
+                                            Sequence<1, K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
@@ -466,7 +476,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, KPerBlock, NPerBlock, K1>,
+                                            Sequence<1, K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
@@ -491,8 +501,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
-        // a_mtx[KPerBlock, MPerBlock] is in LDS
-        // b_mtx[KPerBlock, NPerBlock] is in LDS
+        // a_mtx[K0PerBlock, MPerBlock] is in LDS
+        // b_mtx[K0PerBlock, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
@@ -518,8 +528,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;

-        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
@@ -546,31 +556,35 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4

         // main body
         index_t k_block_data_begin = 0;

-        do
-        {
-            a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                a_k0_m_k1_grid_move_slice_window_step_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                b_k0_n_k1_grid_move_slice_window_step_hack);
-
-            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-
-            block_sync_lds();
-
-            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
-
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-            block_sync_lds();
-
-            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
-            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
-
-            k_block_data_begin += KPerBlock;
-        } while(k_block_data_begin < (K0 - KPerBlock));
+        if constexpr(HasMainKBlockLoop)
+        {
+            do
+            {
+                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
+                                                    a_block_slice_copy_step,
+                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
+                                                    b_block_slice_copy_step,
+                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+
+                a_blockwise_copy.RunRead(
+                    a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+
+                block_sync_lds();
+
+                b_blockwise_copy.RunRead(
+                    b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+
+                block_sync_lds();
+
+                a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
+                b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
+
+                k_block_data_begin += K0PerBlock;
+            } while(k_block_data_begin < (K0 - K0PerBlock));
+        }

         // tail
         {
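Reviewer note: this hunk is the loop bug itself. The old do/while always executed at least once, and its first action advances the A/B source slice windows, so with K0 == K0PerBlock it stepped into a second K0 tile that does not exist. The kernel fixes this with if constexpr on the HasMainKBlockLoop template parameter; a minimal host-side simulation of the two structures, assuming a single tile (window stands in for the slice-window position):

#include <cstdio>

int main()
{
    const int K0 = 4, K0PerBlock = 4;

    // old structure: the body runs once even though K0 - K0PerBlock == 0,
    // pushing the window past the only tile
    int window = 0, k = 0;
    do
    {
        window += K0PerBlock;
        k += K0PerBlock;
    } while(k < (K0 - K0PerBlock));
    std::printf("old: window at %d (out of bounds)\n", window); // window at 4

    // new structure: the main body is skipped when there is no main loop,
    // leaving the tail to process the single tile at window 0
    const bool has_main_k0_block_loop = K0 > K0PerBlock;
    window = 0;
    k      = 0;
    if(has_main_k0_block_loop)
    {
        do
        {
            window += K0PerBlock;
            k += K0PerBlock;
        } while(k < (K0 - K0PerBlock));
    }
    std::printf("new: window at %d\n", window); // window at 0
}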
@@ -95,13 +95,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     const auto GemmN = Y * X * C;
     const auto GemmKTotal = N * Ho * Wo;

-    const auto GemmK = GemmKTotal / GemmK1;
-
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0 = BatchLen * GemmKPerBlock;
-    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
+    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
               << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
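Reviewer note: a worked example of why the rewrite matters, using hypothetical sizes GemmKTotal = 33, GemmK1 = 4, GemmKPerBlock = 4, GemmKBatch = 2:

    old: GemmK = 33 / 4 = 8 (truncated), BatchLen = ceil(8 / 8) = 1,
         GemmK0 = 1 * 4 = 4, GemmKPad = 2 * 4 * 4 = 32 < 33  (pad falls short of GemmKTotal)
    new: GemmK0 = ceil(33 / (4 * 4 * 2)) * 4 = 2 * 4 = 8,
         GemmKPad = 2 * 8 * 4 = 64 >= 33

The same two-line replacement is applied to the v4r4r4 and v4r4r5 drivers below, which also drops the now-unused BatchLen.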
@@ -123,13 +123,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
     const auto GemmN = K;
     const auto GemmKTotal = N * Ho * Wo;

-    const auto GemmK = GemmKTotal / GemmK1;
-
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0 = BatchLen * GemmKPerBlock;
-    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
+    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
               << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
@@ -107,8 +107,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;

     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
+#elif 1
+    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
@@ -291,13 +291,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
     const auto GemmN = Y * X * C;
     const auto GemmKTotal = N * Ho * Wo;

-    const auto GemmK = GemmKTotal / GemmK1;
-
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0 = BatchLen * GemmKPerBlock;
-    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
+    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
               << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
@@ -156,27 +156,58 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
         std::cout << "gridSize : " << grid_size << std::endl;
     }

-    const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
-                                                FloatAB,
-                                                FloatC,
-                                                remove_reference_t<ABK0MK1GridDesc>,
-                                                remove_reference_t<BBK0NK1GridDesc>,
-                                                remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
-                                                remove_reference_t<CBlockClusterAdaptor>>;
+    const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+
+    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+    float ave_time = 0;
 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-    float ave_time = launch_and_time_kernel(kernel,
-                                            nrepeat,
-                                            dim3(grid_size),
-                                            dim3(BlockSize),
-                                            0,
-                                            p_a_grid,
-                                            p_b_grid,
-                                            p_c_grid,
-                                            a_b_k0_m_k1_grid_desc,
-                                            b_b_k0_n_k1_grid_desc,
-                                            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                                            c_block_cluster_adaptor);
+    if(has_main_k0_block_loop)
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    true>;
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_b_k0_m_k1_grid_desc,
+                                          b_b_k0_n_k1_grid_desc,
+                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                          c_block_cluster_adaptor);
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    false>;
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_b_k0_m_k1_grid_desc,
+                                          b_b_k0_n_k1_grid_desc,
+                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                          c_block_cluster_adaptor);
+    }

 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
     DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
@@ -189,20 +220,58 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
     c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

-    float ave_time = launch_and_time_kernel(
-        kernel,
-        nrepeat,
-        dim3(grid_size),
-        dim3(BlockSize),
-        0,
-        p_a_grid,
-        p_b_grid,
-        p_c_grid,
-        cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(
-            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    if(has_main_k0_block_loop)
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    true>;
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    false>;
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
 #endif
     return ave_time;
 }
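Reviewer note: the driver changes in this last pair of hunks follow the usual runtime-to-compile-time dispatch pattern — the host computes has_main_k0_block_loop from the runtime K0, then instantiates the kernel with the matching boolean template argument so the if constexpr branch inside Run is resolved at compile time. A condensed sketch of the pattern, stripped of the ck types:

template <bool HasMainKBlockLoop>
void kernel_stub()
{
    // stands in for kernel_gemm_xdlops_v2r4: inside, if constexpr(HasMainKBlockLoop)
    // decides whether the main K0 loop is compiled in at all
}

void dispatch(int K0, int K0PerBlock)
{
    const bool has_main_k0_block_loop = K0 > K0PerBlock;

    if(has_main_k0_block_loop)
        kernel_stub<true>();  // instantiation with the main loop
    else
        kernel_stub<false>(); // tail-only instantiation
}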