From c16842688518ff140e4fc1247087b470eb160f23 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Wed, 26 Nov 2025 10:14:40 +0000 Subject: [PATCH] Index -> Dim --- .../transform_conv_bwd_weight_to_gemm.hpp | 42 +++++++++---------- .../transform_conv_bwd_weight_to_gemm_v2.hpp | 24 +++++------ 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index e4e2a8bbfc..54aec0ca6f 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -175,9 +175,9 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; - const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch; if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) @@ -195,7 +195,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -213,7 +213,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -251,7 +251,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -290,7 +290,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -366,9 +366,9 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; - const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -387,7 +387,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -402,7 +402,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -431,7 +431,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -472,7 +472,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -551,9 +551,9 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; - const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -572,7 +572,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -587,7 +587,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -616,7 +616,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -666,7 +666,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index e1576ec27d..9864ae2698 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -359,9 +359,9 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise // kernel compatibility - const index_t KBatchIndexA = + const index_t KBatchDimA = (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; - const index_t KBatchIndexB = + const index_t KBatchDimB = (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); @@ -382,7 +382,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -398,7 +398,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -428,7 +428,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -469,7 +469,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -748,9 +748,9 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise // kernel compatibility - const index_t KBatchIndexA = + const index_t KBatchDimA = (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; - const index_t KBatchIndexB = + const index_t KBatchDimB = (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); @@ -771,7 +771,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -787,7 +787,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -817,7 +817,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -882,7 +882,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}));