diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 87d2701e01..ba0dd063a7 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -244,6 +244,19 @@ struct ConvTensorTypes using EDataType = int8_t; }; +template <> +struct ConvTensorTypes +{ + using ADataType = ck::f8_t; + using AComputeType = ck::f8_t; + using BDataType = ck::f8_t; + using BComputeType = ck::f8_t; + using CShuffleDataType = ck::f8_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = ck::f8_t; +}; + template struct ElementwiseOps { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 15e1428419..73612648ee 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_2d_fp8.cpp conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index d6cda1f427..331a842dff 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v2_intrawave}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 31f2976fd0..43df6a1fd9 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::DEFAULT, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v1_intrawave}; @@ -49,7 +49,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::FILTER_3x3, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v5_intrawave}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 6276424a77..f27212202b 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -21,7 +21,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v3_intrawave}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index a390510199..f84d4a705d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -21,7 +21,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_128x128x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v4_intrawave}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp new file mode 100644 index 0000000000..c6c7acb8c2 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp @@ -0,0 +1,39 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +// 2D FP8 NHWGC (channels-last) with Pipeline V1 and DEFAULT +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP8, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ + .thread_block = FwdThreadBlock_256x128x32, + .gridwise_gemm = FwdGemmParams_Xdl_4x2_per_wave, + .block_transfer = FwdBlockTransfer_4x64x1_fp8, + .fwd_specialization = ConvFwdSpecialization::DEFAULT, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages =1, + .num_groups_to_merge = 1, + .loop_scheduler = PipelineScheduler::DEFAULT + }; + + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + "256, 256, 128, 32", + "Default"}); +} + +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 3c59ae24fb..0ff31b6edd 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::DEFAULT, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v3_intrawave}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 14d2811918..7644a9fd4e 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_128x128x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v4_intrawave}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index bce092d5f6..e11170d132 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, .gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave, - .block_transfer = FwdBlockTransfer_4x64_1, + .block_transfer = FwdBlockTransfer_4x64x1, .fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0, .gemm_specialization = GemmSpecialization::MNKPadding, .block_gemm = BlockGemmDesc_v1_intrawave}; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 017af87ab6..8a152b287d 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -12,7 +12,7 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; -constexpr BlockTransferABC FwdBlockTransfer_4x64_1{ +constexpr BlockTransferABC FwdBlockTransfer_4x64x1{ .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, .thread_cluster_dims_c = {.m_block = 1, @@ -37,6 +37,31 @@ constexpr BlockTransferABC FwdBlockTransfer_4x64_1{ .src_access_order_a = {1, 0, 2}, .src_access_order_b = {1, 0, 2}}; +constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{ + .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + constexpr BlockTransferABC FwdBlockTransfer_4x16x1{ .block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, @@ -90,6 +115,9 @@ constexpr BlockTransferABC FwdBlockTransfer_4x32x1{ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; + constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; @@ -103,6 +131,9 @@ constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8 constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr ThreadBlock FwdThreadBlock_256x128x32{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + constexpr ThreadBlock FwdThreadBlock_128x128x32{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}};