mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
fix bug in gridwise gemm xdlops v2r3 (#45)
This commit is contained in:
@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
typename CBlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -37,14 +38,14 @@ __global__ void
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
template <typename GridwiseGemm,
|
||||
@@ -81,14 +82,14 @@ __global__ void
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -102,7 +103,7 @@ template <index_t BlockSize,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t K1Value,
|
||||
@@ -158,13 +159,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check M01, N01
|
||||
@@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
@@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
@@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -386,13 +395,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -400,7 +409,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, MPerBlock, K1>,
|
||||
Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
@@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, NPerBlock, K1>,
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// a_mtx[K0PerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[K0PerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
@@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||
@@ -504,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}
|
||||
|
||||
// main body
|
||||
index_t k_block_data_begin = 0;
|
||||
index_t k0_block_data_begin = 0;
|
||||
|
||||
do
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += KPerBlock;
|
||||
} while(k_block_data_begin < (K0 - KPerBlock));
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (K0 - K0PerBlock));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
|
||||
@@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
|
||||
@@ -148,28 +148,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
||||
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>>;
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
||||
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
||||
@@ -181,20 +214,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
|
||||
float ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user