mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Tweak GEMM kernel (#38)
* add parameters
* tweak gemm
* tweak
* update conv
* update script
* adding bwd 1x1
* update script
* adding 1x1 bwd
* debugging bwd 1x1 failure
* update script
* update script
* test
* test v100
* clean up
[ROCm/composable_kernel commit: b3e8d57d51]
This commit is contained in:
@@ -21,8 +21,8 @@ template <typename... Wei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildaValue,
|
||||
index_t IXTildaValue,
|
||||
typename IYTilda,
|
||||
typename IXTilda,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildaValue>,
|
||||
Number<IXTildaValue>,
|
||||
IYTilda i_ytilda,
|
||||
IXTilda i_xtilda,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilda = Number<IYTildaValue>{};
|
||||
constexpr auto IXTilda = Number<IXTildaValue>{};
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = 1
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = C
|
||||
// GemmK = K
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform(
|
||||
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
__host__ __device__ constexpr auto make_right_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const RightPadLength& right_pad,
|
||||
|
||||
@@ -29,7 +29,7 @@ __global__ void
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
@@ -132,7 +132,9 @@ template <index_t BlockSize,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsExtraM,
|
||||
bool BBlockLdsExtraN>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -152,14 +154,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
@@ -171,29 +193,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
// TODO: turn on this
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
|
||||
(NPerBlock % (NRepeat * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check M01, N01
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
if(!(M0 % M01 == 0 && N0 % N01 == 0))
|
||||
return false;
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
@@ -212,11 +250,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
@@ -233,8 +295,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
@@ -245,23 +308,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
#if 1
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
|
||||
|
||||
const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#elif 1
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
|
||||
make_tuple(Sequence<1, 0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#endif
|
||||
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
@@ -296,14 +367,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
|
||||
@@ -90,8 +90,8 @@
|
||||
#endif
|
||||
|
||||
// pass tensor descriptor by value or void*
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
// merge transformation use magic number division
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
|
||||
|
||||
13
host/driver_offline/include/debug.hpp
Normal file
13
host/driver_offline/include/debug.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef DEBUG_HPP
|
||||
#define DEBUG_HPP
|
||||
|
||||
namespace debug {
|
||||
namespace debug_driver_gemm_xdlops_v2r3 {
|
||||
|
||||
// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
|
||||
static ck::index_t M01 = 1;
|
||||
static ck::index_t N01 = 1;
|
||||
|
||||
} // namespace debug_driver_gemm_xdlops_v2r3
|
||||
} // namespace debug
|
||||
#endif
|
||||
@@ -48,8 +48,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -76,7 +76,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -105,7 +105,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -159,34 +159,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
@@ -294,13 +266,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug_driver_gemm_xdlops_v2r3::N01,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
|
||||
@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -159,25 +159,93 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-:
|
||||
// gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
@@ -215,7 +284,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
//clang-format on
|
||||
// clang-format on
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
@@ -225,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
|
||||
{
|
||||
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
|
||||
{
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
i_ytilda,
|
||||
i_xtilda,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);
|
||||
|
||||
if(GemmK0 != 0)
|
||||
{
|
||||
ave_time += driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
|
||||
@@ -0,0 +1,389 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations&,
|
||||
const InLeftPads&,
|
||||
const InRightPads&,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
// clang-format on
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -203,18 +203,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
|
||||
@@ -49,7 +49,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -77,7 +77,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -133,7 +133,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -160,8 +160,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -189,7 +189,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
@@ -316,13 +372,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
|
||||
@@ -4,16 +4,8 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType,
|
||||
typename AccType,
|
||||
typename CType,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc>
|
||||
void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
|
||||
const BDesc& b_k_n_grid_desc,
|
||||
const CDesc& c_m_n_grid_desc,
|
||||
const Tensor<ABType>& a_k_m,
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k_m_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_k_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
|
||||
263
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
Normal file
263
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
Normal file
@@ -0,0 +1,263 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -4,16 +4,8 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType,
|
||||
typename AccType,
|
||||
typename CType,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc>
|
||||
void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
|
||||
const BDesc& b_n_k_grid_desc,
|
||||
const CDesc& c_m_n_grid_desc,
|
||||
const Tensor<ABType>& a_k_m,
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
const auto N = b_n_k_grid_desc.GetLength(I0);
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k_m_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_n_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_unmerge_transform(make_tuple(K0, K1Number))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
|
||||
263
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
Normal file
263
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
Normal file
@@ -0,0 +1,263 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -4,16 +4,8 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType,
|
||||
typename AccType,
|
||||
typename CType,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc>
|
||||
void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
|
||||
const BDesc& b_k_n_grid_desc,
|
||||
const CDesc& c_m_n_grid_desc,
|
||||
const Tensor<ABType>& a_m_k,
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
@@ -33,8 +22,148 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
@@ -88,46 +217,157 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k_grid_desc.GetLength(I1);
|
||||
const auto M = a_m_k_grid_desc.GetLength(I0);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_m_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(M),
|
||||
make_unmerge_transform(make_tuple(K0, K1Number))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_k_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
@@ -147,9 +387,9 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
@@ -194,13 +434,17 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
|
||||
291
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
Normal file
291
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
Normal file
@@ -0,0 +1,291 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -4,16 +4,8 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType,
|
||||
typename AccType,
|
||||
typename CType,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc>
|
||||
void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
const BDesc& b_n_k_grid_desc,
|
||||
const CDesc& c_m_n_grid_desc,
|
||||
const Tensor<ABType>& a_m_k,
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
@@ -34,6 +23,34 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -60,9 +77,93 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
@@ -90,7 +191,7 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
@@ -117,8 +218,36 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
@@ -144,46 +273,131 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k_grid_desc.GetLength(I1);
|
||||
const auto M = a_m_k_grid_desc.GetLength(I0);
|
||||
const auto N = b_n_k_grid_desc.GetLength(I0);
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
#if 1
|
||||
// non-padded GEMM
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_m_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(M),
|
||||
make_unmerge_transform(make_tuple(K0, K1Number))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_n_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_unmerge_transform(make_tuple(K0, K1Number))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
@@ -203,9 +417,80 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
#else
|
||||
// padded GEMM
|
||||
const auto a_k0_m_k1_grid_desc_tmp =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_right_pad_transform(M, MRightPad),
|
||||
make_pass_through_transform(K1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc_tmp,
|
||||
make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
#endif
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
@@ -250,13 +535,17 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
|
||||
347
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
Normal file
347
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
Normal file
@@ -0,0 +1,347 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R3
|
||||
#define DRIVER_GEMM_XDLOPS_V2R3
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP
|
||||
#define DRIVER_GEMM_XDLOPS_V2R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
@@ -46,13 +46,17 @@ template <ck::index_t BlockSize,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
@@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
CAccessOrderMRepeatNRepeat>;
|
||||
CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
{
|
||||
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
@@ -123,7 +129,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
@@ -134,7 +141,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
|
||||
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
const auto c_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -14,15 +15,16 @@
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp"
|
||||
|
||||
#define USE_MODE 1
|
||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
|
||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 0
|
||||
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
||||
|
||||
enum ConvBackwardDataAlgo
|
||||
{
|
||||
V4R1XDLNHWC,
|
||||
V4R1R2XDLNHWC,
|
||||
V4R1XDLNHWC, // 0
|
||||
V4R1R2XDLNHWC, // 1
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -280,20 +282,43 @@ int main(int argc, char* argv[])
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 &&
|
||||
in_right_pad_w == 0)
|
||||
{
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1<
|
||||
in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -24,7 +25,7 @@
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -16,11 +17,19 @@
|
||||
#include "device_gemm_xdlops_mk_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_mk_kn_nm.hpp"
|
||||
#include "device_gemm_xdlops_mk_nk_nm.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_nm.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_nm.hpp"
|
||||
|
||||
#define USE_GEMM_XDL_MK_KN_MN 1
|
||||
#define USE_GEMM_XDL_MK_NK_MN 1
|
||||
#define USE_GEMM_XDL_KM_KN_MN 1
|
||||
#define USE_GEMM_XDL_KM_NK_MN 1
|
||||
#define USE_GEMM_XDL_MK_KN_NM 0
|
||||
#define USE_GEMM_XDL_MK_NK_NM 0
|
||||
#define USE_GEMM_XDL_KM_KN_NM 0
|
||||
#define USE_GEMM_XDL_KM_NK_NM 0
|
||||
|
||||
enum GemmAlgo
|
||||
{
|
||||
@@ -28,21 +37,21 @@ enum GemmAlgo
|
||||
Xdl_MK_NK_MN, // 1
|
||||
Xdl_KM_KN_MN, // 2
|
||||
Xdl_KM_NK_MN, // 3
|
||||
Xdl_MK_KN_NM, // 4
|
||||
Xdl_MK_NK_NM, // 5
|
||||
Xdl_KM_KN_NM, // 6
|
||||
Xdl_KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
// dynamic mode
|
||||
if(argc != 10)
|
||||
if(argc != 12)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: M, N, K\n");
|
||||
printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -57,6 +66,9 @@ int main(int argc, char* argv[])
|
||||
const index_t N = std::stoi(argv[8]);
|
||||
const index_t K = std::stoi(argv[9]);
|
||||
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);
|
||||
|
||||
#if 0
|
||||
using ab_data_t = float;
|
||||
using acc_data_t = float;
|
||||
@@ -74,69 +86,44 @@ int main(int argc, char* argv[])
|
||||
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
|
||||
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
|
||||
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
// A
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN ||
|
||||
layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
a_strides_host[0] = static_cast<std::size_t>(K);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
a_strides_host[0] = static_cast<std::size_t>(K);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
b_strides_host[0] = static_cast<std::size_t>(K);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
else
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
a_strides_host[0] = static_cast<std::size_t>(M);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
a_strides_host[0] = static_cast<std::size_t>(M);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
// B
|
||||
if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN ||
|
||||
layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
b_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
b_strides_host[0] = static_cast<std::size_t>(K);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
// C
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN ||
|
||||
layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
@@ -144,7 +131,10 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
c_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
c_strides_host[0] = static_cast<std::size_t>(M);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
|
||||
@@ -185,38 +175,6 @@ int main(int argc, char* argv[])
|
||||
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_mk_kn_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
auto f_make_for_device_mk_nk_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
auto f_make_for_device_km_kn_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
auto f_make_for_device_km_nk_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_MN)
|
||||
{
|
||||
@@ -225,10 +183,7 @@ int main(int argc, char* argv[])
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_mk_kn_mn();
|
||||
|
||||
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -240,10 +195,7 @@ int main(int argc, char* argv[])
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_mk_nk_mn();
|
||||
|
||||
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -255,10 +207,7 @@ int main(int argc, char* argv[])
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_km_kn_mn();
|
||||
|
||||
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -270,10 +219,55 @@ int main(int argc, char* argv[])
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_km_nk_mn();
|
||||
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
#if USE_GEMM_XDL_MK_KN_NM
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_NK_NM
|
||||
if(algo == GemmAlgo::Xdl_MK_NK_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_KN_NM
|
||||
if(algo == GemmAlgo::Xdl_KM_KN_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_NK_NM
|
||||
if(algo == GemmAlgo::Xdl_KM_NK_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#define DEVICE_HPP
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
@@ -74,6 +76,8 @@ float launch_and_time_kernel(
|
||||
|
||||
timer.End();
|
||||
|
||||
// std::this_thread::sleep_for (std::chrono::microseconds(10));
|
||||
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,10 @@ enum GemmMatrixLayout
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
MK_KN_NM, // 4
|
||||
MK_NK_NM, // 5
|
||||
KM_KN_NM, // 6
|
||||
KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -80,6 +80,78 @@ void host_gemm(const Tensor<AType>& a,
|
||||
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
auto f_mk_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
auto f_mk_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
auto f_km_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
auto f_km_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
|
||||
14
script/docker-rocm4.3.1.sh
Executable file
14
script/docker-rocm4.3.1.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
WORKSPACE=$1
|
||||
echo "workspace: " $WORKSPACE
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v $WORKSPACE:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
|
||||
#--network host \
|
||||
151
script/run.sh
151
script/run.sh
@@ -4,24 +4,12 @@
|
||||
export ROCR_VISIBLE_DEVICE=0
|
||||
export GPU_DEVICE_ORDINAL=0
|
||||
|
||||
## Boost
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
|
||||
|
||||
## Compiling
|
||||
#export OLC_DEBUG_HIP_VERBOSE=1
|
||||
#export OLC_DEBUG_HIP_DUMP=1
|
||||
#export OLC_DEBUG_SAVE_TEMP_DIR=1
|
||||
|
||||
#rm -rf /root/_hip_binary_kernels_/
|
||||
#rm -rf /tmp/olCompile*
|
||||
|
||||
#make -j conv_fwd_driver_offline
|
||||
make -j conv_fwd_driver_offline
|
||||
#make -j conv_bwd_driver_offline
|
||||
#make -j conv_wrw_driver_offline
|
||||
#make -j conv_fwd_driver_online
|
||||
|
||||
make -j gemm_driver_offline
|
||||
#make -j gemm_driver_offline
|
||||
|
||||
DRIVER="./host/driver_offline/conv_fwd_driver_offline"
|
||||
LAYOUT=$1
|
||||
ALGO=$2
|
||||
VERIFY=$3
|
||||
@@ -29,30 +17,121 @@ INIT=$4
|
||||
LOG=$5
|
||||
REPEAT=$6
|
||||
|
||||
################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
#M01=$7
|
||||
#N01=$8
|
||||
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
KBATCH=$7
|
||||
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
|
||||
#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
|
||||
#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
|
||||
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
################################################ layout algo verify init log repeat M___ N___ K___
|
||||
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
|
||||
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048
|
||||
./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096
|
||||
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
######### layout algo verify init log repeat M___ N___ K___
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
|
||||
|
||||
# Resnet50
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1
|
||||
|
||||
# 256x128x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
|
||||
|
||||
|
||||
|
||||
# 128x128x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
|
||||
|
||||
|
||||
# 128x64x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
|
||||
# 64x128x32 c64
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
|
||||
# 64x64x32 c32
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448
|
||||
|
||||
Reference in New Issue
Block a user