mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Don't miss half of the elements
This commit is contained in:
@@ -615,15 +615,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
input_right_pads,
|
||||
k_batch_,
|
||||
split_k_offset_a_hack_,
|
||||
- split_k_offset_b_hack_);
|
||||
+ split_k_offset_b_hack_,
|
||||
+ true); // use_full_batch_kindex=true for V1-compatible descriptors
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
- // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
- // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
|
||||
+ // With use_full_batch_kindex=true, descriptors contain full k-batch dimension
|
||||
+ // so we divide by k_batch_ to get per-batch stride
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
@@ -810,7 +811,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
- if(gemm_arg.KBatch > 1)
|
||||
+ if(arg.k_batch_ > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
@@ -842,7 +843,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
- if(gemm_arg.KBatch > 1)
|
||||
+ if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
@@ -1151,7 +1152,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
- if(gemm_arg.KBatch > 1)
|
||||
+ if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
@@ -1220,7 +1221,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
else
|
||||
{
|
||||
- if(gemm_arg.KBatch > 1)
|
||||
+ if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
@@ -1293,7 +1294,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
- if(gemm_arg.KBatch > 1)
|
||||
+ if(arg.k_batch_ > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -326,7 +326,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
- const bool split_k_offset_b_hack = false)
|
||||
+ const bool split_k_offset_b_hack = false,
|
||||
+ const bool use_full_batch_kindex = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -355,9 +356,13 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
- const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
- const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
- const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
+ const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
+ // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
|
||||
+ // kernel compatibility
|
||||
+ const index_t KBatchIndexA =
|
||||
+ (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
+ const index_t KBatchIndexB =
|
||||
+ (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
|
||||
@@ -501,7 +506,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
- const bool split_k_offset_b_hack = false)
|
||||
+ const bool split_k_offset_b_hack = false,
|
||||
+ const bool use_full_batch_kindex = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -537,9 +543,13 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
- const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
- const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
- const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
+ const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
+ // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
|
||||
+ // kernel compatibility
|
||||
+ const index_t KBatchIndexA =
|
||||
+ (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
+ const index_t KBatchIndexB =
|
||||
+ (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
|
||||
@@ -691,7 +701,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
- const bool split_k_offset_b_hack = false)
|
||||
+ const bool split_k_offset_b_hack = false,
|
||||
+ const bool use_full_batch_kindex = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -734,9 +745,13 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
- const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
- const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
- const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
+ const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
+ // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
|
||||
+ // kernel compatibility
|
||||
+ const index_t KBatchIndexA =
|
||||
+ (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
+ const index_t KBatchIndexB =
|
||||
+ (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
|
||||
Reference in New Issue
Block a user