mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] Add backward weight instance traits for xdl cshuffle. (#3143)
* Add backward weight instance traits for xdl cshuffle. To keep instance test file sizes reasonable, we start a new test_bwd_weight_instances_traits.cpp test file. * Fix copyright notices. * Remove (c) symbol, replace with (C). Having UTF-8 in source caused an error with code generation.
This commit is contained in:
@@ -20,6 +20,7 @@ endfunction()
|
||||
add_ck_builder_test(test_ckb_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_fwd_instance_traits.cpp
|
||||
test_bwd_weight_instance_traits.cpp
|
||||
test_instance_traits_util.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
|
||||
@@ -30,7 +31,8 @@ add_ck_builder_test(test_ckb_get_instance_string
|
||||
test_get_instance_string_fwd_grp_conv.cpp
|
||||
test_get_instance_string_fwd_grp_conv_large_tensor.cpp
|
||||
test_get_instance_string_fwd_grp_conv_wmma.cpp
|
||||
test_get_instance_string_fwd_grp_conv_dl.cpp)
|
||||
test_get_instance_string_fwd_grp_conv_dl.cpp
|
||||
test_get_instance_string_bwd_weight_grp_conv_xdl.cpp)
|
||||
|
||||
# Testing the fwd convolution builder requires kernel compilation.
|
||||
# To enable parallel compilation, the individual tests are split into separate files.
|
||||
|
||||
112
experimental/builder/test/test_bwd_weight_instance_traits.cpp
Normal file
112
experimental/builder/test/test_bwd_weight_instance_traits.cpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(InstanceTraits, BwdWeightXdlCShuffleInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
false, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
ck::half_t, // ComputeTypeA
|
||||
ck::half_t, // ComputeTypeB
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // InLayout
|
||||
",GKYXC" // WeiLayout
|
||||
",GNHWK" // OutLayout
|
||||
",fp16" // InDataType
|
||||
",fp16" // WeiDataType
|
||||
",fp16" // OutDataType
|
||||
",fp32" // AccDataType
|
||||
",PassThrough" // InElementwiseOperation
|
||||
",PassThrough" // WeiElementwiseOperation
|
||||
",PassThrough" // OutElementwiseOperation
|
||||
",Default" // ConvBackwardWeightSpecialization
|
||||
",256" // BlockSize
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",4" // K0PerBlock
|
||||
",8" // K1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",2" // MXdlPerWave
|
||||
",2" // NXdlPerWave
|
||||
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_K1
|
||||
",false" // ABlockLdsAddExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_K1
|
||||
",false" // BBlockLdsAddExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CBlockTransferClusterLengths
|
||||
",8" // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
",fp16" // ComputeTypeA
|
||||
",fp16" // ComputeTypeB
|
||||
",1" // MaxTransposeTransferSrcScalarPerVector
|
||||
",1>"; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
@@ -0,0 +1,86 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for backward weight XDL variant
|
||||
TEST(GetInstanceString, ReturnsStringForBwdWeightGrpConvXdlInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple = ck::tensor_operation::device::instance::
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // InLayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // WeiLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // OutLayout
|
||||
ck::tensor_operation::device::instance::
|
||||
ConvBwdWeightDefault>; // ConvBwdWeightSpecialization
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using the most general operator base
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances
|
||||
// This corresponds to the configuration with BlockSize=64, MPerBlock=64, NPerBlock=64, etc.
|
||||
std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // InLayout
|
||||
",GKYXC" // WeiLayout
|
||||
",GNHWK" // OutLayout
|
||||
",fp16" // InDataType
|
||||
",fp16" // WeiDataType
|
||||
",fp16" // OutDataType
|
||||
",fp32" // AccDataType
|
||||
",PassThrough" // InElementwiseOperation
|
||||
",PassThrough" // WeiElementwiseOperation
|
||||
",PassThrough" // OutElementwiseOperation
|
||||
",Default" // ConvBackwardWeightSpecialization
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",4" // K0PerBlock
|
||||
",8" // K1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",2" // MXdlPerWave
|
||||
",2" // NXdlPerWave
|
||||
",Seq(1,4,8,2)" // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
",Seq(0,3,1,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(0,2,1,3)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",2" // ABlockTransferSrcScalarPerVector
|
||||
",4" // ABlockTransferDstScalarPerVector_K1
|
||||
",true" // ABlockLdsAddExtraM
|
||||
",Seq(1,4,8,2)" // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
",Seq(0,3,1,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(0,2,1,3)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",2" // BBlockTransferSrcScalarPerVector
|
||||
",4" // BBlockTransferDstScalarPerVector_K1
|
||||
",true" // BBlockLdsAddExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CBlockTransferClusterLengths
|
||||
",2" // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
",fp16" // ComputeTypeA
|
||||
",fp16" // ComputeTypeB
|
||||
",1" // MaxTransposeTransferSrcScalarPerVector
|
||||
",1>"; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
Reference in New Issue
Block a user