mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Index -> Dim
This commit is contained in:
@@ -175,9 +175,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -195,7 +195,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -213,7 +213,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -251,7 +251,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -290,7 +290,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -366,9 +366,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
|
||||
@@ -387,7 +387,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -402,7 +402,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -431,7 +431,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -472,7 +472,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -551,9 +551,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
@@ -572,7 +572,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -587,7 +587,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -616,7 +616,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -666,7 +666,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -359,9 +359,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
|
||||
// kernel compatibility
|
||||
const index_t KBatchIndexA =
|
||||
const index_t KBatchDimA =
|
||||
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB =
|
||||
const index_t KBatchDimB =
|
||||
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
|
||||
@@ -382,7 +382,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -398,7 +398,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -428,7 +428,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -469,7 +469,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -748,9 +748,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
|
||||
// kernel compatibility
|
||||
const index_t KBatchIndexA =
|
||||
const index_t KBatchDimA =
|
||||
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB =
|
||||
const index_t KBatchDimB =
|
||||
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
@@ -771,7 +771,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -787,7 +787,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -817,7 +817,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -882,7 +882,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
Reference in New Issue
Block a user