mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Regression fix for cshuffle
This commit is contained in:
@@ -33,6 +33,111 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Dispatch helper function for split-K hack - handles 4-way dispatch based on runtime flags
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid,
|
||||
const FloatB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
void* p_shared,
|
||||
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
if(split_k_offset_a_hack && split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, true, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_a_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, true, false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, false, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, false, false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
@@ -84,22 +189,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
|
||||
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack,
|
||||
k_batch);
|
||||
DispatchBatchedGemmSplitKHack<GridwiseGemm,
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
AGridDesc_B_K0_M_K1,
|
||||
BGridDesc_B_K0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
Block2CTileMap,
|
||||
HasMainKBlockLoop>(
|
||||
p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack,
|
||||
k_batch);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -474,7 +591,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
false>, // Both true/false give the same occupancy.
|
||||
false>, // HasMainKBlockLoop - both true/false give the same occupancy
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
return std::max(1, max_occupancy);
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -149,7 +150,9 @@ template <typename GridwiseGemm,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CBlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop>
|
||||
bool HasMainKBlockLoop,
|
||||
bool SplitKOffsetAHack,
|
||||
bool SplitKOffsetBHack>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -167,8 +170,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
@@ -177,22 +178,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c_block_cluster_adaptor,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack,
|
||||
k_batch);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, SplitKOffsetAHack, SplitKOffsetBHack>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c_block_cluster_adaptor,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -207,8 +207,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = c_block_cluster_adaptor;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
ignore = k_batch;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
@@ -551,7 +549,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation>();
|
||||
CGlobalMemoryDataOperation_>();
|
||||
}
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2CTileMap>
|
||||
@@ -661,7 +659,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
template <bool HasMainKBlockLoop,
|
||||
bool SplitKOffsetAHack = false,
|
||||
bool SplitKOffsetBHack = false>
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
@@ -676,8 +676,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
@@ -688,20 +686,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
|
||||
const index_t k_batch_id = block_work_idx[I0];
|
||||
|
||||
const long_index_t split_k_offset_a =
|
||||
split_k_offset_a_hack ? k_batch_id * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b =
|
||||
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
|
||||
// Use compile-time branching based on template parameters
|
||||
const long_index_t split_k_offset_a = SplitKOffsetAHack ? k_batch_id * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = SplitKOffsetBHack ? k_batch_id * split_k_stride_b : 0;
|
||||
|
||||
// When hack is enabled, buffer size equals the stride (calculated from descriptor's
|
||||
// CalculateOffset method in the device layer). This properly accounts for the
|
||||
// descriptor's transform pipeline and non-compact strides.
|
||||
// When hack is disabled, use the full element space size.
|
||||
const long_index_t a_buffer_size =
|
||||
split_k_offset_a_hack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
|
||||
SplitKOffsetAHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
|
||||
|
||||
const long_index_t b_buffer_size =
|
||||
split_k_offset_b_hack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
|
||||
SplitKOffsetBHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
|
||||
|
||||
ignore = k_batch; // k_batch value itself not used in this function
|
||||
|
||||
@@ -763,7 +760,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
true>(
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
make_multi_index(
|
||||
split_k_offset_a_hack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
|
||||
SplitKOffsetAHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
@@ -794,7 +791,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
true>(
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
make_multi_index(
|
||||
split_k_offset_b_hack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
|
||||
SplitKOffsetBHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
|
||||
Reference in New Issue
Block a user