mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection (#3592)
* added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp * added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle * added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3 * added reflection of max_transpose parameters * fix printing of std optional parameters * fix use of undefined ck::index * added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle * added xdl two stage instance to reflection * added additional variables * added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle, _v3, grouped_conv_two_stage_wmma_cshuffle_v3, * added reflection for device_grouped_conv_bwd_weigh_wmma_cshuffle_v3 * added reflection for bwd_weight_wmma_cshuffle * added comments back in * add printed output for optional parameters * update README * fix typo * added num_gemm_k_prefetch_stage and small fixes * modified test string due to reflection of new parameter --------- Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
@@ -259,9 +259,118 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Instance = ckb::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
|
||||
EXPECT_THAT(
|
||||
ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
"│ ├─ Weight Layout: GKYXC\n"
|
||||
"│ ├─ Output Layout: GNHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 2\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
|
||||
{
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // AK1
|
||||
32, // MPerWMMA
|
||||
32, // NPerXDL
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
4, // NumGroupsToMerge
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector>
|
||||
|
||||
EXPECT_THAT(ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"2D Backward Weight Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
@@ -272,37 +381,146 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Data tile size: 128×128×16\n"
|
||||
" ├─ Struct does not contain optional gemm_padding argument\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Pipeline version: V1\n"
|
||||
" ├─ Pipeline scheduler: DEFAULT\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" └─ Vector access (GMEM write) instruction size: 2"));
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
|
||||
" ├─ Max Transpose dst scalar per vector: 1\n"
|
||||
" └─ Num groups to merge: 4"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest)
|
||||
{
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
3, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNDHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKZYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNDHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerWmma
|
||||
32, // NPerWmma
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
1, // NummGemmKPrefetchStage
|
||||
ck::LoopScheduler::Default, // BlkGemmPipeSched
|
||||
ck::PipelineVersion::v1, // BlkGemmPipelineVer
|
||||
false>; // BComputeDataType
|
||||
|
||||
EXPECT_THAT(
|
||||
ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"3D Backward Weight Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNDHWC\n"
|
||||
"│ ├─ Weight Layout: GKZYXC\n"
|
||||
"│ ├─ Output Layout: GNDHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 128×128×16\n"
|
||||
" ├─ Struct does not contain optional gemm_padding argument\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V1\n"
|
||||
" ├─ Pipeline scheduler: DEFAULT\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Num gemm k prefetch stage: 1\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString)
|
||||
|
||||
Reference in New Issue
Block a user