Merge remote-tracking branch 'origin/vpietila/ckb-remove-explicit-device-op-flag' into vpietila/ckb-fwd-bwd-instances

This commit is contained in:
Ville Pietilä
2025-11-06 08:53:54 -06:00
11 changed files with 94 additions and 10 deletions

View File

@@ -244,6 +244,19 @@ struct ConvTensorTypes<DataType::I8>
using EDataType = int8_t;
};
template <>
struct ConvTensorTypes<DataType::FP8>
{
using ADataType = ck::f8_t;
using AComputeType = ck::f8_t;
using BDataType = ck::f8_t;
using BComputeType = ck::f8_t;
using CShuffleDataType = ck::f8_t;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = ck::f8_t;
};
template <ElementwiseOperation T>
struct ElementwiseOps
{

View File

@@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp
conv/test_ckb_conv_fwd_2d_fp8.cpp
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp

View File

@@ -23,7 +23,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v2_intrawave};

View File

@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::DEFAULT,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v1_intrawave};
@@ -49,7 +49,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::FILTER_3x3,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v5_intrawave};

View File

@@ -21,7 +21,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v3_intrawave};

View File

@@ -21,7 +21,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_128x128x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v4_intrawave};

View File

@@ -0,0 +1,39 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
// 2D FP8 NHWGC (channels-last) with Pipeline V1 and DEFAULT
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::FP8,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x128x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x2_per_wave,
.block_transfer = FwdBlockTransfer_4x64x1_fp8,
.fwd_specialization = ConvFwdSpecialization::DEFAULT,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages =1,
.num_groups_to_merge = 1,
.loop_scheduler = PipelineScheduler::DEFAULT
};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"256, 256, 128, 32",
"Default"});
}
} // namespace

View File

@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::DEFAULT,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v3_intrawave};

View File

@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_128x128x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v4_intrawave};

View File

@@ -22,7 +22,7 @@ TEST(FwdConvInstances,
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
.gridwise_gemm = FwdGemmParams_Xdl_4x4_per_wave,
.block_transfer = FwdBlockTransfer_4x64_1,
.block_transfer = FwdBlockTransfer_4x64x1,
.fwd_specialization = ConvFwdSpecialization::FILTER_1X1_PAD0,
.gemm_specialization = GemmSpecialization::MNKPadding,
.block_gemm = BlockGemmDesc_v1_intrawave};

View File

@@ -12,7 +12,7 @@ namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
constexpr BlockTransferABC FwdBlockTransfer_4x64_1{
constexpr BlockTransferABC FwdBlockTransfer_4x64x1{
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
@@ -37,6 +37,31 @@ constexpr BlockTransferABC FwdBlockTransfer_4x64_1{
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
@@ -90,6 +115,9 @@ constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
@@ -103,6 +131,9 @@ constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8
constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_256x128x32{.block_size = 256,
.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128x128x32{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};