mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-23 06:16:12 +00:00
Tweak GEMM kernel (#38)
* add parameters * tweak gemm * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * clean up
This commit is contained in:
@@ -21,8 +21,8 @@ template <typename... Wei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildaValue,
|
||||
index_t IXTildaValue,
|
||||
typename IYTilda,
|
||||
typename IXTilda,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildaValue>,
|
||||
Number<IXTildaValue>,
|
||||
IYTilda i_ytilda,
|
||||
IXTilda i_xtilda,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilda = Number<IYTildaValue>{};
|
||||
constexpr auto IXTilda = Number<IXTildaValue>{};
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = 1
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = C
|
||||
// GemmK = K
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform(
|
||||
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
__host__ __device__ constexpr auto make_right_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const RightPadLength& right_pad,
|
||||
|
||||
@@ -29,7 +29,7 @@ __global__ void
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
@@ -132,7 +132,9 @@ template <index_t BlockSize,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsExtraM,
|
||||
bool BBlockLdsExtraN>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -152,14 +154,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
@@ -171,29 +193,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
// TODO: turn on this
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
|
||||
(NPerBlock % (NRepeat * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check M01, N01
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
if(!(M0 % M01 == 0 && N0 % N01 == 0))
|
||||
return false;
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
@@ -212,11 +250,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
@@ -233,8 +295,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
@@ -245,23 +308,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
#if 1
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
|
||||
|
||||
const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#elif 1
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
|
||||
make_tuple(Sequence<1, 0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#endif
|
||||
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
@@ -296,14 +367,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
|
||||
@@ -90,8 +90,8 @@
|
||||
#endif
|
||||
|
||||
// pass tensor descriptor by value or void*
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
// merge transformation use magic number division
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
|
||||
|
||||
Reference in New Issue
Block a user