mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Merge remote-tracking branch 'origin/vpietila/ckb-remove-explicit-device-op-flag' into vpietila/ckb-fwd-bwd-instances
This commit is contained in:
@@ -244,6 +244,19 @@ struct ConvTensorTypes<DataType::I8>
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
// FP8 specialization of ConvTensorTypes.
// Every tensor-facing alias (A, B, their compute types, the C-shuffle staging
// type and the output E) is ck::f8_t; only the accumulator is float.
// DsDataTypes is an empty ck::Tuple<>, i.e. this specialization lists no
// additional D tensor types.
template <>
|
||||
struct ConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck::f8_t;
|
||||
using AComputeType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using BComputeType = ck::f8_t;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float; // the only alias in this specialization that is not ck::f8_t
|
||||
using EDataType = ck::f8_t;
|
||||
};
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
|
||||
@@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
|
||||
@@ -23,7 +23,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x256x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v2_intrawave};
|
||||
|
||||
@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x256x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::DEFAULT,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v1_intrawave};
|
||||
@@ -49,7 +49,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x256x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::FILTER_3x3,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v5_intrawave};
|
||||
|
||||
@@ -21,7 +21,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x256x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v3_intrawave};
|
||||
|
||||
@@ -21,7 +21,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_128x128x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v4_intrawave};
|
||||
|
||||
39
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp
Normal file
39
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
// 2D FP8 forward convolution, NHWGC_GKYXC_NHWGK (channels-last) layout, DEFAULT
// forward specialization, built through the non-V3
// ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle descriptor.
// NOTE(review): this comment previously said "Pipeline V1", but the algorithm
// initializer below sets no pipeline-version / block_gemm field (only
// num_gemm_k_prefetch_stages and loop_scheduler) - confirm the intended wording.
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
|
||||
{
|
||||
// Signature: what is being computed (dimension, direction, layout, data type, elementwise op).
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP8,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
// Algorithm: tuning parameters, assembled from the shared FP8/256x128x32 test configs.
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x128x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x2_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1_fp8,
|
||||
.fwd_specialization = ConvFwdSpecialization::DEFAULT,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages =1,
|
||||
.num_groups_to_merge = 1,
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT
|
||||
};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
// run_test receives three strings: the device-op name, "256, 256, 128, 32"
// (which lines up with block_size 256 and tile 256x128x32 of
// FwdThreadBlock_256x128x32) and "Default". Presumably these are fragments
// expected in the built instance's description - confirm against run_test.
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
"256, 256, 128, 32",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x256x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::DEFAULT,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v3_intrawave};
|
||||
|
||||
@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_128x128x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v4_intrawave};
|
||||
|
||||
@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock_256x256x32,
|
||||
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
|
||||
.block_transfer = FwdBlockTransfer_4x64_1,
|
||||
.block_transfer = FwdBlockTransfer_4x64x1,
|
||||
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc_v1_intrawave};
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64_1{
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
@@ -37,6 +37,31 @@ constexpr BlockTransferABC FwdBlockTransfer_4x64_1{
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
// Block-transfer configuration referenced by the 2D FP8 forward-convolution test.
// The 4x64x1 in the name matches the k0 = 4, m_n = 64, k1 = 1 settings used for
// both the A and B transfers. All vector widths in this config (LDS source,
// LDS destination and epilogue) are 8 scalars per vector.
// NOTE(review): the _fp8 suffix suggests FP8-specific tuning relative to
// FwdBlockTransfer_4x64x1 - confirm which fields actually differ.
constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
// A and B use identical LDS transfer settings: vector dimension 2, 8-wide
// source and destination vectors, LDS padding on, direct load off.
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
// The same {1, 0, 2} order is used for all four access-order fields.
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
@@ -90,6 +115,9 @@ constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
|
||||
// Gridwise XDL GEMM parameters; the "4x4" in the name matches
// m_xdl_per_wave = 4 and n_xdl_per_wave = 4. ak1/bk1 are 8 and the per-XDL
// sizes are 32x32.
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
|
||||
// Gridwise XDL GEMM parameters; the "4x2" in the name matches
// m_xdl_per_wave = 4 and n_xdl_per_wave = 2. ak1/bk1 are 8 and the per-XDL
// sizes are 32x32.
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
|
||||
|
||||
// Gridwise XDL GEMM parameters; the "2x1" in the name matches
// m_xdl_per_wave = 2 and n_xdl_per_wave = 1. ak1/bk1 are 8 and the per-XDL
// sizes are 32x32.
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
|
||||
@@ -103,6 +131,9 @@ constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8
|
||||
// Thread-block configuration: block_size 256 with a 256 x 256 x 32 (m x n x k)
// tile. The numbers in the name are the tile dimensions.
constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
// Thread-block configuration: block_size 256 with a 256 x 128 x 32 (m x n x k)
// tile. The numbers in the name are the tile dimensions.
constexpr ThreadBlock FwdThreadBlock_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
// Thread-block configuration: 128 x 128 x 32 (m x n x k) tile.
// Note that block_size is still 256 here; the numbers in the name describe the
// tile dimensions only, not the block size.
constexpr ThreadBlock FwdThreadBlock_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user