Regression fix for cshuffle

This commit is contained in:
Graner, Johannes
2025-12-05 14:05:27 +00:00
parent 56caf529f8
commit bb6a3571a2
2 changed files with 164 additions and 50 deletions

View File

@@ -33,6 +33,111 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Dispatch helper function for split-K hack - handles 4-way dispatch based on runtime flags
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid,
const FloatB* p_b_grid,
FloatC* p_c_grid,
void* p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack,
index_t k_batch)
{
if(split_k_offset_a_hack && split_k_offset_b_hack)
{
GridwiseGemm::template Run<HasMainKBlockLoop, true, true>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
else if(split_k_offset_a_hack)
{
GridwiseGemm::template Run<HasMainKBlockLoop, true, false>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
else if(split_k_offset_b_hack)
{
GridwiseGemm::template Run<HasMainKBlockLoop, false, true>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, false, false>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
}
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
@@ -84,22 +189,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
split_k_offset_a_hack,
split_k_offset_b_hack,
k_batch);
DispatchBatchedGemmSplitKHack<GridwiseGemm,
FloatA,
FloatB,
FloatC,
AGridDesc_B_K0_M_K1,
BGridDesc_B_K0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
Block2CTileMap,
HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
split_k_offset_a_hack,
split_k_offset_b_hack,
k_batch);
}
#else
ignore = p_a_grid;
@@ -474,7 +591,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
false>, // Both true/false give the same occupancy.
false>, // HasMainKBlockLoop - both true/false give the same occupancy
BlockSize,
dynamic_smem_size));
return std::max(1, max_occupancy);

View File

@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
@@ -149,7 +150,9 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool SplitKOffsetAHack,
bool SplitKOffsetBHack>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -167,8 +170,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CBlockClusterAdaptor c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
@@ -177,22 +178,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor,
split_k_stride_a,
split_k_stride_b,
split_k_offset_a_hack,
split_k_offset_b_hack,
k_batch);
GridwiseGemm::template Run<HasMainKBlockLoop, SplitKOffsetAHack, SplitKOffsetBHack>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
#else
ignore = p_a_grid;
@@ -207,8 +207,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = c_block_cluster_adaptor;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -551,7 +549,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
MRepeat,
NRepeat,
FloatC,
CGlobalMemoryDataOperation>();
CGlobalMemoryDataOperation_>();
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
@@ -661,7 +659,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop>
template <bool HasMainKBlockLoop,
bool SplitKOffsetAHack = false,
bool SplitKOffsetBHack = false>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
@@ -676,8 +676,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const CBlockClusterAdaptor& c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack,
index_t k_batch)
{
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
@@ -688,20 +686,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const index_t k_batch_id = block_work_idx[I0];
const long_index_t split_k_offset_a =
split_k_offset_a_hack ? k_batch_id * split_k_stride_a : 0;
const long_index_t split_k_offset_b =
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
// Use compile-time branching based on template parameters
const long_index_t split_k_offset_a = SplitKOffsetAHack ? k_batch_id * split_k_stride_a : 0;
const long_index_t split_k_offset_b = SplitKOffsetBHack ? k_batch_id * split_k_stride_b : 0;
// When hack is enabled, buffer size equals the stride (calculated from descriptor's
// CalculateOffset method in the device layer). This properly accounts for the
// descriptor's transform pipeline and non-compact strides.
// When hack is disabled, use the full element space size.
const long_index_t a_buffer_size =
split_k_offset_a_hack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
SplitKOffsetAHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
const long_index_t b_buffer_size =
split_k_offset_b_hack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
SplitKOffsetBHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
ignore = k_batch; // k_batch value itself not used in this function
@@ -763,7 +760,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(
split_k_offset_a_hack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
SplitKOffsetAHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_b_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0),
@@ -794,7 +791,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(
split_k_offset_b_hack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
SplitKOffsetBHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0),