mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] Add bwd weight factories (#3509)
* Add placeholder test. * Initial conv bwd weight factory. * Conv builder test refactoring. * Add missing pieces to bwd weight factory. * Improve compile time erros message when no matching factory is found. * Use amcro to ensure automatic macthing between concepts are their string representations. * Improve compile time diagnostics. * Small improvements. * Improve missing member/wrong type compile-time errors. * Improve compile time diagnostics. * Concept bug fixes. * Remove debug assert. * Update algorithm signature diagnostics. * Factory bug fixes. * First functional version of bwd weight conv factory. * Refactor handing of GEMM-K batch template parameter in conv bwd weight factory. * Concept improvements. * Improve concept diagnostics. * Introduve a common size type for concepts. * Update compiletime diagnostics to use the size type. * Update conv specialization enum. * Fix fwd conv builder tests. * Fix smoke tests. * Separate bwd weigth and bwd data tests into separate targets. * Clean-up CK Tile builder tests. * Add bwd weight XDL CShuffle V3 factory. * Build conv bwd weigth v3 instances successfully. * Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. * Test fix. * Add instance traits for bwd weight algorithms. * Add unit tests for instance strings. * Build new instance traits unit tests but exclude WMMA for now. * Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. * Conv bwd weight DL factory. * Final implementation for bwd weight DL factory. * Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance. * Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle * Treat ref algorithm the same way as real algorithms in the dispatcher. * Refactor large tensor support and WMMA configuration. * Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3. * Update Readme. * Fix WMMA bwd weight tests. * Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3. * Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle. * Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle. * Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 * Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensor in bwd weigth convs. * Fix fwd factories after refactoring. * clang-format * Move compile-time diagnostics to a separate branch. * Fix ref algorithm dispatching. * Fix smoke tests. * clang-format * Fix factory for regular WMMA conv bwd weight. * Clarify builder Readme. * Remove obsolete test file. * Fix test after merge. * clang-format * Remove the C++26 extensions. * Unify conv elementwise ops and layout definitions for fwd and bwd directions. * Remove old layout and elementwise ops. * Unify handling of conv tensor types between fwd and bwd directions. * Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank. * Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms. * clang-format --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -15,31 +15,42 @@ using namespace test;
|
||||
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
constexpr DlTransferABC DlFwdTransfer{.a =
|
||||
{
|
||||
.block_transfer = DlBlockTransferAB,
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = DlBlockTransferAB,
|
||||
},
|
||||
.c = {
|
||||
.epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4},
|
||||
}};
|
||||
constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2,
|
||||
.b = DlBlockTransfer_8x1x1x2,
|
||||
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4}};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x64x1{
|
||||
constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{
|
||||
.thread_slice_lengths = {1, 8, 1, 1, 1},
|
||||
.thread_cluster_lengths = {1, 2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {0, 2, 3, 1, 4},
|
||||
.src_access_order = {0, 2, 3, 1, 4},
|
||||
.src_vector_tensor_lengths = {1, 1, 1, 1, 1},
|
||||
.src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 1, 1}};
|
||||
|
||||
constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1,
|
||||
.b = DlBlockTransfer_1x8x1x1x1,
|
||||
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 1}};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -72,7 +83,73 @@ constexpr TransferABC FwdTransfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x64x1_fp8{
|
||||
constexpr Transfer<4> BwdTransfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1_fp8{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -105,7 +182,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x16x1{
|
||||
constexpr Transfer<> Transfer_4x16x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
@@ -139,7 +216,7 @@ constexpr TransferABC FwdTransfer_4x16x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x32x1{
|
||||
constexpr Transfer<> Transfer_4x32x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
@@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 8}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 32, .n = 32, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_1x1x1{
|
||||
constexpr TileTransfer TileTransfer_1x1x1{
|
||||
.a_scalar_per_vector = 1,
|
||||
.b_scalar_per_vector = 1,
|
||||
.c_scalar_per_vector = 1,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_4x4x4{
|
||||
constexpr TileTransfer TileTransfer_4x4x4{
|
||||
.a_scalar_per_vector = 4,
|
||||
.b_scalar_per_vector = 4,
|
||||
.c_scalar_per_vector = 4,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_8x8x8{
|
||||
constexpr TileTransfer TileTransfer_8x8x8{
|
||||
.a_scalar_per_vector = 8,
|
||||
.b_scalar_per_vector = 8,
|
||||
.c_scalar_per_vector = 8,
|
||||
};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
|
||||
@@ -54,7 +54,7 @@ inline std::string to_string<PipelineScheduler>(PipelineScheduler t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvFwdSpecialization>(ConvFwdSpecialization t)
|
||||
inline std::string to_string<ConvSpecialization>(ConvSpecialization t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
@@ -86,11 +86,20 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseXdlGemm>(GridwiseXdlGemm t)
|
||||
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << ","
|
||||
<< t.m_xdl_per_wave << "," << t.n_xdl_per_wave;
|
||||
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
|
||||
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseFwdXdlGemm>(GridwiseFwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl
|
||||
<< "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -104,17 +113,29 @@ inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm>(BlockGemm t)
|
||||
inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockTransfer>(BlockTransfer t)
|
||||
template <size_t ThreadClusterRank>
|
||||
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
if constexpr(ThreadClusterRank == 4)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else if constexpr(ThreadClusterRank == 3)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4,
|
||||
"Unsupported ThreadClusterRank");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -134,14 +155,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<AccessOrder>(AccessOrder t)
|
||||
template <size_t N>
|
||||
inline std::string to_string(AccessOrder<N> t)
|
||||
{
|
||||
return array_to_seq(t.order);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferAB>(TransferAB t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(InputTransfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
|
||||
@@ -152,7 +173,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferC>(TransferC t)
|
||||
inline std::string to_string<OutputTransfer>(OutputTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
|
||||
@@ -160,8 +181,8 @@ inline std::string to_string<TransferC>(TransferC t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferABC>(TransferABC t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(Transfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -185,7 +206,19 @@ inline std::string to_string<DlThreadCluster>(DlThreadCluster t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer>(DlBlockTransfer t)
|
||||
inline std::string to_string<DlBlockTransfer<4>>(DlBlockTransfer<4> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
<< "," << array_to_seq(t.thread_cluster_arrange_order) << ","
|
||||
<< array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths)
|
||||
<< "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << ","
|
||||
<< array_to_seq(t.dst_vector_tensor_lengths);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer<5>>(DlBlockTransfer<5> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
@@ -206,19 +239,24 @@ inline std::string to_string<DlEpilogue>(DlEpilogue t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferAB>(DlBlockTransferAB t)
|
||||
inline std::string to_string<TransposeParams_>(TransposeParams_ t)
|
||||
{
|
||||
return to_string(t.block_transfer);
|
||||
std::ostringstream oss;
|
||||
oss << t.max_transpose_transfer_src_scalar_per_vector << ","
|
||||
<< t.max_transpose_transfer_dst_scalar_per_vector;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferC>(DlBlockTransferC t)
|
||||
inline std::string to_string<DlTransfer<4>>(DlTransfer<4> t)
|
||||
{
|
||||
return to_string(t.epilogue);
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransferABC>(DlTransferABC t)
|
||||
inline std::string to_string<DlTransfer<5>>(DlTransfer<5> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -234,7 +272,13 @@ inline std::string to_string<ThreadBlock_>(ThreadBlock_ t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<XdlGemm_>(XdlGemm_ t)
|
||||
inline std::string to_string<FwdXdlGemm_>(FwdXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
@@ -245,33 +289,40 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Transfer_>(Transfer_ t)
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
inline std::string to_string(Transfer_<ThreadClusterRank> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecialization_>(ConvSpecialization_ t)
|
||||
inline std::string to_string<ConvSpecializationFwd_>(ConvSpecializationFwd_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.bwd_weight_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Prefetch_>(Prefetch_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << ","
|
||||
<< to_string(t.loop_scheduler);
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm_>(BlockGemm_ t)
|
||||
{
|
||||
return to_string(t.block_gemm);
|
||||
return to_string(t.block_gemm_pipeline);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -287,7 +338,13 @@ inline std::string to_string<DlThreadCluster_>(DlThreadCluster_ t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_>(DlTransfer_ t)
|
||||
inline std::string to_string<DlTransfer_<4>>(DlTransfer_<4> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_<5>>(DlTransfer_<5> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
@@ -299,8 +356,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -309,8 +366,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -320,7 +377,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -332,7 +389,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_>(t));
|
||||
<< to_string(static_cast<DlTransfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -340,7 +397,102 @@ template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t)
|
||||
{
|
||||
return to_string(t.base_algorithm);
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_<5>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user