mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Fix transform and instances for grouped conv bwd data (#848)
* Fix transform and instances for grouped conv bwd data
* Add instances for small K and small C
* Remove workaround after fix
* Fix interface tests
[ROCm/composable_kernel commit: 595d23be14]
This commit is contained in:
@@ -200,9 +200,6 @@
|
||||
// workaround: compiler issue on gfx908
|
||||
#define CK_WORKAROUND_SWDEV_388832 1
|
||||
|
||||
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
|
||||
#define CK_WORKAROUND_GITHUB_ISSUE_824 1
|
||||
|
||||
// flag to enable (1) or disable (0) the debugging output in some kernels
|
||||
#define DEBUG_LOG 0
|
||||
|
||||
|
||||
@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto AK = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto BK = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
// check consistency of desc
|
||||
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
|
||||
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// check tile size
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K / KPerBlock;
|
||||
const auto num_k_loop = AK / KPerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
|
||||
@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
|
||||
@@ -236,8 +236,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
|
||||
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
|
||||
|
||||
const index_t AK0 = K / AK1;
|
||||
|
||||
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
|
||||
const auto out_grid_desc =
|
||||
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
|
||||
@@ -247,6 +245,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t AK0 = math::integer_divide_ceil(K, AK1);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
@@ -308,6 +308,9 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
// A: output tensor
|
||||
@@ -332,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
@@ -340,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
@@ -352,21 +355,28 @@ struct TransformConvBwdDataToGemm_v1
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
Sequence<5>{}));
|
||||
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
const auto out_gemmk_gemmm_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
out_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(AK1, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
}
|
||||
@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto
|
||||
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
|
||||
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
@@ -437,22 +447,29 @@ struct TransformConvBwdDataToGemm_v1
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7, 8>{}));
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
|
||||
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
|
||||
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, AK0)),
|
||||
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
|
||||
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
|
||||
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
const auto out_gemmk_gemmm_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
out_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(AK1, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
}
|
||||
@@ -505,8 +522,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
|
||||
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
|
||||
|
||||
const index_t BK0 = K / BK1;
|
||||
|
||||
// assume packed
|
||||
// k_y_x_c for 2d or k_z_y_x_c for 3d
|
||||
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
|
||||
@@ -515,6 +530,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t BK0 = math::integer_divide_ceil(K, BK1);
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
@@ -551,6 +568,9 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
|
||||
|
||||
// B weight tensor
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
@@ -566,43 +586,47 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
const auto wei_gemmk_gemmn_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
|
||||
GemmNPerBlock,
|
||||
BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
wei_gemmk_gemmnraw_grid_desc,
|
||||
make_tuple(BK1, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(
|
||||
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
}
|
||||
@@ -631,10 +655,10 @@ struct TransformConvBwdDataToGemm_v1
|
||||
Sequence<5, 6>{},
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc =
|
||||
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_slice_transform(ZDot, I0, ZDotSlice),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
@@ -650,33 +674,37 @@ struct TransformConvBwdDataToGemm_v1
|
||||
Sequence<4>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<5>{}));
|
||||
Sequence<4>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, BK0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
const auto wei_gemmk_gemm_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
|
||||
GemmNPerBlock,
|
||||
BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
wei_gemmk_gemmnraw_grid_desc,
|
||||
make_tuple(BK1, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemm_padded_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return wei_gemmbk0_gemm_gemmbk1_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user