mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
No more regression, use templates instead
This commit is contained in:
@@ -35,6 +35,225 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Helper function to dispatch split-K hack for standard kernel (single LDS)
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType>
|
||||
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
const typename GridwiseGemm::Argument& karg,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
index_t k_id,
|
||||
index_t k_batch,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
if(split_k_offset_a_hack && split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_a_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to dispatch split-K hack for 2lds kernel
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType>
|
||||
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const typename GridwiseGemm::Argument& karg,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
index_t k_id,
|
||||
index_t k_batch,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
if(split_k_offset_a_hack && split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_a_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
@@ -80,23 +299,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -156,24 +376,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
DispatchSplitKHack_2Lds<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
|
||||
@@ -34,6 +34,225 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Helper function to dispatch split-K hack for standard kernel (single LDS)
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType>
|
||||
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
const typename GridwiseGemm::Argument& karg,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
index_t k_id,
|
||||
index_t k_batch,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
if(split_k_offset_a_hack && split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_a_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to dispatch split-K hack for 2lds kernel
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType>
|
||||
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const typename GridwiseGemm::Argument& karg,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
index_t k_id,
|
||||
index_t k_batch,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
if(split_k_offset_a_hack && split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_a_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else if(split_k_offset_b_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
@@ -77,23 +296,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -158,24 +379,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
DispatchSplitKHack_2Lds<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
|
||||
@@ -663,7 +663,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
TailNumber TailNum = TailNumber::Odd,
|
||||
bool SplitKOffsetAHack = false,
|
||||
bool SplitKOffsetBHack = false>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
@@ -673,13 +675,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
{
|
||||
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetAHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetBHack ? k_batch : 1;
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
|
||||
@@ -750,7 +750,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
make_multi_index(SplitKOffsetAHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -781,7 +781,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(SplitKOffsetBHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1030,7 +1030,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
TailNumber TailNum = TailNumber::Odd,
|
||||
bool SplitKOffsetAHack = false,
|
||||
bool SplitKOffsetBHack = false>
|
||||
__device__ static void Run_2Lds(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
@@ -1041,13 +1043,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
{
|
||||
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetAHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetBHack ? k_batch : 1;
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
|
||||
@@ -1118,7 +1118,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
make_multi_index(SplitKOffsetAHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(SplitKOffsetBHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
|
||||
Reference in New Issue
Block a user