mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Enable DPP8 GEMM on Navi3 (#892)
This commit is contained in:
committed by
GitHub
parent
562b4cec48
commit
8f84a01237
@@ -168,7 +168,8 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1030")
|
||||
if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(karg);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,8 @@ __global__ void
|
||||
#endif
|
||||
kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
|
||||
|
||||
Reference in New Issue
Block a user