diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
index 27d7e0882a..9058bb63a4 100644
--- a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
@@ -468,7 +468,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             {
                 continue;
             }
-
+
             const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
                 N,
                 K,
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp
index 8d222d53dc..f4a6677b3e 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp
@@ -19,7 +19,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
 ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
 ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
 ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
-ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
 // clang-format on
 
 } // namespace device_reduce_instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
index dceb797302..272ae982c1 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
@@ -10,7 +10,7 @@ namespace device {
 namespace device_gemm_instance {
 
 using BF16 = ck::bhalf_t;
-using F32 = float;
+using F32 = float;
 
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -21,8 +21,9 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
 // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
-using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances =
+    std::tuple<
+        // clang-format off
         //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
         //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
         //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
@@ -43,8 +44,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple<
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
     std::vector>& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
index 33e33b4988..ebcde34546 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
@@ -10,7 +10,7 @@ namespace device {
 namespace device_gemm_instance {
 
 using BF16 = ck::bhalf_t;
-using F32 = float;
+using F32 = float;
 
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -21,8 +21,9 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
 // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
-using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances =
+    std::tuple<
+        // clang-format off
         //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
         //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
         //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
@@ -43,8 +44,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple<
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
     std::vector>& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
index 319db8ea7f..4e35adfeab 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
@@ -10,7 +10,7 @@ namespace device {
 namespace device_gemm_instance {
 
 using BF16 = ck::bhalf_t;
-using F32 = float;
+using F32 = float;
 
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -21,8 +21,9 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
-using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances =
+    std::tuple<
+        // clang-format off
         //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
         //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
         //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
@@ -43,8 +44,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
         DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
     std::vector>& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
index 4b3524c30e..346b1a4bec 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -47,11 +47,33 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
 
 // using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
 // // clang-format off
-// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
-// //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
-// //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, 1, 9, S<1, 2, 1, 72>, 2>
+// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A|
+// B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl|
+// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
+// ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer|
+// BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|
+// CBlockTransferClusterLengths| CBlockTransfer|
+// //#########################| Type| Type| Type| Type| | | |
+// Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| |
+// XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim|
+// SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder|
+// SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|
+// _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+// //#########################| | | | | | | |
+// Operation| Operation| Operation| | | | | | | |
+// | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector|
+// PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | |
+// PerVector| PerVector_K1| | PerShuffle| PerShuffle|
+// _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+// //#########################| | | | | | | | | |
+// | | | | | | | | | | | |
+// | | | | | | | | | | | |
+// | | | | | |
+// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row,
+// PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16,
+// 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8,
+// true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2,
+// true, 1, 9, S<1, 2, 1, 72>, 2>
 // // clang-format on
 // >;
 
diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp
index eb5ba53571..dd9f79ee41 100644
--- a/profiler/src/profiler.cpp
+++ b/profiler/src/profiler.cpp
@@ -19,7 +19,6 @@ int profile_grouped_gemm(int, char*[]);
 
 int main(int argc, char* argv[])
 {
-#if 0
     if(strcmp(argv[1], "gemm") == 0)
     {
         return profile_gemm(argc, argv);
@@ -86,7 +85,4 @@ int main(int argc, char* argv[])
 
         return 0;
     }
-#else
-    profile_grouped_gemm(argc, argv);
-#endif
 }
diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp
index b60a496218..8037ee5c08 100644
--- a/test/gemm/gemm_bf16.cpp
+++ b/test/gemm/gemm_bf16.cpp
@@ -32,10 +32,14 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(std::vector&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(std::vector&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(std::vector&);
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
+    std::vector&);
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
+    std::vector&);
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
+    std::vector&);
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
+    std::vector&);
 } // namespace device_gemm_instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp
index 0f4f1cbf01..99073bbd8d 100644
--- a/test/gemm/gemm_int8.cpp
+++ b/test/gemm/gemm_int8.cpp
@@ -32,11 +32,15 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(std::vector&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(std::vector&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(std::vector&);
-}
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
+    std::vector&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
+    std::vector&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
+    std::vector&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
+    std::vector&);
+} // namespace device_gemm_instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp
index f18055879c..7cf539aa26 100644
--- a/test/include/test_util.hpp
+++ b/test/include/test_util.hpp
@@ -55,10 +55,10 @@ check_err(const std::vector& out,
 }
 
 bool check_err(const std::vector<_Float16>& out,
-               const std::vector<_Float16>& ref,
-               const std::string& msg,
-               _Float16 rtol = static_cast<_Float16>(1e-3f),
-               _Float16 atol = static_cast<_Float16>(1e-3f))
+               const std::vector<_Float16>& ref,
+               const std::string& msg,
+               _Float16 rtol = static_cast<_Float16>(1e-3f),
+               _Float16 atol = static_cast<_Float16>(1e-3f))
 {
     if(out.size() != ref.size())
     {
@@ -69,14 +69,14 @@ bool check_err(const std::vector<_Float16>& out,
     }
 
     bool res{true};
-    int err_count = 0;
-    double err = 0;
-    double max_err = std::numeric_limits<_Float16>::min();
+    int err_count = 0;
+    double err = 0;
+    double max_err = std::numeric_limits<_Float16>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
         double out_ = double(out[i]);
         double ref_ = double(ref[i]);
-        err = std::abs(out_ - ref_);
+        err = std::abs(out_ - ref_);
         if(err > atol + rtol * std::abs(ref_) || !std::isfinite(out_) || !std::isfinite(ref_))
         {
             max_err = err > max_err ? err : max_err;