GemmGemm TNNT instances (#399)

* add gemm_gemm TNNT instance

* sanitize Gemm1KPack

* disable instances that failed validation on mi100

[ROCm/composable_kernel commit: fe52c94c98]
This commit is contained in:
Anthony Chang
2022-09-07 02:38:01 +08:00
committed by GitHub
parent 5643625481
commit 2d119fda7b
6 changed files with 107 additions and 3 deletions

View File

@@ -32,6 +32,20 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Col,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename B0Layout,
typename B1Layout,
@@ -82,6 +96,12 @@ struct DeviceOperationInstanceFactory<
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
}
}
return op_ptrs;
}