fix mfma_int8 on MI300

This commit is contained in:
Jing Zhang
2023-02-28 20:20:46 +00:00
parent 3973caa48c
commit 1b409ffe5b
43 changed files with 214 additions and 135 deletions

View File

@@ -51,8 +51,8 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx1030__) || defined(__gfx940__) // for GPU code
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \
defined(__gfx940__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
@@ -65,10 +65,14 @@
#define CK_USE_AMD_MFMA
#endif
#if (defined(__gfx90a__) || defined(__gfx940__))
#if(defined(__gfx90a__) || defined(__gfx940__))
#define CK_USE_AMD_MFMA_BF16_1K_OP
#endif
#if defined(__gfx940__)
#define CK_USE_AMD_MFMA_GFX940
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
@@ -94,7 +98,7 @@
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#if (defined(__gfx90a__) || defined(__gfx940__)) // for GPU code
#if(defined(__gfx90a__) || defined(__gfx940__)) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0

View File

@@ -47,7 +47,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -416,7 +417,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -43,7 +43,8 @@ __global__ void
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
@@ -678,7 +679,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -56,7 +56,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
@@ -938,7 +939,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -56,7 +56,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
@@ -839,7 +840,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -74,7 +74,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

View File

@@ -60,7 +60,8 @@ __global__ void
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -588,7 +589,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -83,7 +83,8 @@ __global__ void
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -579,7 +580,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -68,7 +68,8 @@ __global__ void
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -790,7 +791,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -59,7 +59,8 @@ __global__ void
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

View File

@@ -62,7 +62,8 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -636,7 +637,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -62,7 +62,8 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -612,7 +613,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -75,7 +75,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

View File

@@ -52,7 +52,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -581,7 +582,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -55,7 +55,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

View File

@@ -51,7 +51,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -456,7 +457,8 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -63,7 +63,8 @@ __global__ void
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
@@ -851,7 +852,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -60,7 +60,8 @@ __global__ void
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -554,7 +555,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -51,7 +51,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -490,7 +491,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -574,7 +574,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -648,7 +648,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -37,7 +37,8 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
@@ -703,7 +704,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}

View File

@@ -130,7 +130,8 @@ __global__ void
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

View File

@@ -78,7 +78,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

View File

@@ -155,7 +155,8 @@ __global__ void
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

View File

@@ -135,7 +135,8 @@ __global__ void
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

View File

@@ -38,7 +38,8 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();

View File

@@ -66,7 +66,8 @@ __global__ void
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

View File

@@ -54,7 +54,8 @@ __global__ void
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

View File

@@ -44,7 +44,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

View File

@@ -57,7 +57,8 @@ __global__ void
const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// TODO ANT: separate into MMA + Epilogue

View File

@@ -165,7 +165,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

View File

@@ -44,7 +44,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,

View File

@@ -43,7 +43,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

View File

@@ -42,7 +42,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

View File

@@ -44,7 +44,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

View File

@@ -46,7 +46,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(

View File

@@ -49,7 +49,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(

View File

@@ -53,7 +53,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(

View File

@@ -9,7 +9,6 @@
namespace ck {
#if (defined(__gfx908__) || defined(__gfx90a__))
enum struct MfmaInstr
{
mfma_f32_32x32x1xf32 = 0,
@@ -28,30 +27,10 @@ enum struct MfmaInstr
mfma_f32_16x16x8bf16,
mfma_i32_32x32x8i8,
mfma_i32_16x16x16i8,
mfma_f64_16x16x4f64
};
#elif (defined(__gfx940__))
enum struct MfmaInstr
{
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32,
mfma_f32_16x16x4xf32,
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16,
mfma_f32_16x16x16f16,
mfma_f32_32x32x8bf16_1k,
mfma_f32_16x16x16bf16_1k,
mfma_f32_32x32x4bf16,
mfma_f32_16x16x8bf16,
mfma_i32_32x32x16i8,
mfma_i32_16x16x16i8,
mfma_i32_16x16x32i8,
mfma_f64_16x16x4f64
};
#endif
template <MfmaInstr instr>
struct mfma_type;
@@ -365,7 +344,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
}
};
#if (defined(__gfx908__) || defined(__gfx90a__))
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
{
@@ -387,29 +365,6 @@ struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#elif (defined(__gfx940__))
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <>
struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
@@ -433,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_16x16x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
{
@@ -571,25 +570,29 @@ struct MfmaSelector
#endif
}
#if (defined(__gfx908__) || defined(__gfx90a__))
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x8i8;
}
#elif (defined(__gfx940__))
#if defined(CK_USE_AMD_MFMA_GFX940)
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
#endif
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x32i8;
}
#else
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x8i8;
}
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x16i8;
}
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};

View File

@@ -259,7 +259,6 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
}
};
#if (defined(__gfx908__) || defined(__gfx90a__))
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;
@@ -278,26 +277,6 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
0);
}
};
#elif (defined(__gfx940__))
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int32_t>(reg_a),
bit_cast<int32_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
0);
}
};
#endif
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;
@@ -318,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x32i8;
template <>
struct intrin_mfma_i32_16x16x32i8<16, 16>
{
template <class FloatC>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

View File

@@ -898,6 +898,8 @@ struct vector_type<T, 256>
}
};
using int64_t = long;
// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;

View File

@@ -10,8 +10,8 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx90a" \
-D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx940" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}