Index -> Dim

This commit is contained in:
Graner, Johannes
2025-11-26 10:14:40 +00:00
parent 5d1d298e2b
commit c168426885
2 changed files with 33 additions and 33 deletions

View File

@@ -175,9 +175,9 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -195,7 +195,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -213,7 +213,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -251,7 +251,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -290,7 +290,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -366,9 +366,9 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -387,7 +387,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -402,7 +402,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -431,7 +431,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -472,7 +472,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -551,9 +551,9 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -572,7 +572,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -587,7 +587,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -616,7 +616,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -666,7 +666,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

View File

@@ -359,9 +359,9 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchIndexA =
const index_t KBatchDimA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchIndexB =
const index_t KBatchDimB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
@@ -382,7 +382,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -398,7 +398,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -428,7 +428,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -469,7 +469,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -748,9 +748,9 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchIndexA =
const index_t KBatchDimA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchIndexB =
const index_t KBatchDimB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
@@ -771,7 +771,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -787,7 +787,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -817,7 +817,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -882,7 +882,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));