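Don't miss half of the elements: build full k-batch descriptors (use_full_batch_kindex) for the CShuffleV3 grouped conv bwd-weight path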

This commit is contained in:
Graner, Johannes
2025-11-17 12:23:32 +00:00
parent e63bc7a2ce
commit fc86ec44f5
2 changed files with 36 additions and 20 deletions

View File

@@ -615,15 +615,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
input_right_pads,
k_batch_,
split_k_offset_a_hack_,
split_k_offset_b_hack_);
split_k_offset_b_hack_,
true); // use_full_batch_kindex=true for V1-compatible descriptors
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
// With use_full_batch_kindex=true, descriptors contain full k-batch dimension
// so we divide by k_batch_ to get per-batch stride
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
@@ -810,7 +811,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -842,7 +843,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
@@ -1151,7 +1152,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1220,7 +1221,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
else
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1293,7 +1294,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,

View File

@@ -326,7 +326,8 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
const bool split_k_offset_b_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;
@@ -355,9 +356,13 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchIndexA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchIndexB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
@@ -501,7 +506,8 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
const bool split_k_offset_b_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;
@@ -537,9 +543,13 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchIndexA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchIndexB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -691,7 +701,8 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
const bool split_k_offset_b_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;
@@ -734,9 +745,13 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchIndexA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchIndexB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);