mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
fixed rdna4 instances
This commit is contained in:
@@ -19,9 +19,9 @@ static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
// static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
// static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
@@ -32,7 +32,7 @@ using device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances =
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>/*,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
@@ -42,7 +42,7 @@ using device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances =
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>*/
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@ static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
// static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
// static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
@@ -32,7 +32,7 @@ using device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances =
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>/*,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
@@ -42,7 +42,7 @@ using device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances =
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>*/
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@ static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
// static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
// static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
@@ -32,7 +32,7 @@ using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances =
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>/*,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
@@ -42,7 +42,7 @@ using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances =
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>*/
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@ static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
//static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
//static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
@@ -32,7 +32,7 @@ using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances =
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>/*,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
@@ -42,7 +42,7 @@ using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances =
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>*/
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ add_subdirectory(conv_tensor_rearrange)
|
||||
add_subdirectory(transpose)
|
||||
add_subdirectory(permute_scale)
|
||||
add_subdirectory(wrapper)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11|gfx12")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
|
||||
add_subdirectory(wmma_op)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "test/wmma_op/wmma_op_util.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType,
|
||||
typename GPUAccType,
|
||||
@@ -62,9 +60,6 @@ int main(int, char*[])
|
||||
pass &= run_test<ck::half_t, ck::half_t, ck::half_t, ck::half_t, 16 >();
|
||||
pass &= run_test<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float, 16 >();
|
||||
pass &= run_test<int8_t, int8_t, int32_t, int32_t, 8 >();
|
||||
#if defined(CK_USE_WMMA_FP8)
|
||||
pass &= run_test<ck::f8_t, ck::f8_t, float, float, 8 >();
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
|
||||
@@ -130,7 +130,7 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
for(int ele = 0; ele < 8; ++ele)
|
||||
{
|
||||
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele];
|
||||
@@ -373,33 +373,53 @@ struct TestWmma
|
||||
ck::wmma_op_util::RunHostGEMM<ReferenceGemmInstance>(
|
||||
a, b, c_host, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
// Unsupported types should be filtered out before calling test operator.
|
||||
bool res = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
|
||||
// Act
|
||||
bool is_supported = ck::is_gfx11_supported() &&
|
||||
ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
|
||||
|
||||
if(std::is_same<CDataType, ck::bhalf_t>::value)
|
||||
if(is_supported)
|
||||
{
|
||||
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
|
||||
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
|
||||
res = ck::utils::check_err(
|
||||
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, float>::value ||
|
||||
std::is_same<CDataType, ck::half_t>::value ||
|
||||
std::is_same<CDataType, int8_t>::value ||
|
||||
std::is_same<CDataType, double>::value ||
|
||||
std::is_same<CDataType, f8_t>::value)
|
||||
{
|
||||
// Run with default error thresholds.
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
// Assert
|
||||
bool res = false;
|
||||
if(std::is_same<CDataType, float>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, ck::half_t>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
|
||||
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
|
||||
res = ck::utils::check_err(
|
||||
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, double>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "UNSUPPORTED CDataType" << std::endl;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user