No more regression, use templates instead

This commit is contained in:
Graner, Johannes
2025-12-04 14:06:55 +00:00
parent cca078ab9c
commit 785be14874
3 changed files with 531 additions and 88 deletions

View File

@@ -35,6 +35,225 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Helper function to dispatch split-K hack for standard kernel (single LDS)
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
if(split_k_offset_a_hack && split_k_offset_b_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_a_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_b_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}
// Helper function to dispatch split-K hack for 2lds kernel
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
if(split_k_offset_a_hack && split_k_offset_b_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_a_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_b_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
@@ -80,23 +299,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;
@@ -156,24 +376,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;

View File

@@ -34,6 +34,225 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Helper function to dispatch split-K hack for standard kernel (single LDS)
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
if(split_k_offset_a_hack && split_k_offset_b_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_a_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_b_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}
// Helper function to dispatch split-K hack for 2lds kernel
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
if(split_k_offset_a_hack && split_k_offset_b_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_a_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else if(split_k_offset_b_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
@@ -77,23 +296,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;
@@ -158,24 +379,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;

View File

@@ -663,7 +663,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
bool SplitKOffsetAHack = false,
bool SplitKOffsetBHack = false>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
@@ -673,13 +675,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0,
const index_t k_batch = 1,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
const index_t k_id = 0,
const index_t k_batch = 1)
{
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
const long_index_t a_space_size_divisor = SplitKOffsetAHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetBHack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
@@ -750,7 +750,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetAHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -781,7 +781,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetBHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1030,7 +1030,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
bool SplitKOffsetAHack = false,
bool SplitKOffsetBHack = false>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
@@ -1041,13 +1043,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0,
const index_t k_batch = 1,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
const index_t k_id = 0,
const index_t k_batch = 1)
{
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
const long_index_t a_space_size_divisor = SplitKOffsetAHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetBHack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
@@ -1118,7 +1118,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetAHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetBHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),