refactored int8 tuples

This commit is contained in:
Wojciech Laskowski
2025-12-08 16:50:58 +00:00
parent 24b1a08444
commit 0688e667df
12 changed files with 156 additions and 374 deletions

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short aliases for element data types. F8 and BF8 exist only when the matching
// CK_ENABLE_FP8 / CK_ENABLE_BF8 macro is defined.
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
// S<...> is shorthand for ck::Sequence<...>; it is used for the compile-time integer lists
// in the instance table below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
// Brings the convolution layout tags into this namespace so they can be named unqualified.
using namespace ck::tensor_layout::convolution;
// Element-wise operation aliases.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
// Named convolution-forward specializations passed as ConvSpec by users of the instance list.
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
// GEMM specialization used by every instance in the table below.
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Pipeline scheduler shorthands. The table below spells the enumerators out in full
// instead of using these two aliases.
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// Tuple of int8 grouped-convolution-forward device instances (the "comp" list, per the alias
// name). Every entry is a DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 specialization with
// I8 for the A/B/CShuffle/E data types, an int32_t accumulator, PassThrough A and B element-wise
// operations, GemmMNKPadding and a block size of 256. Entries differ in tile sizes, block
// transfer settings and pipeline scheduler/version.
//
// Template parameters:
//   NDimSpatial  - forwarded as the first argument of every instance.
//   ALayout, BLayout, DsLayout, ELayout - tensor layouts forwarded to every instance.
//   ConvSpec     - convolution forward specialization shared by every instance.
//   DsDataTypes  - forwarded into the "Ds DataType" column; defaults to an empty Tuple<>.
//   OutElementOp - forwarded into the "CDE Elementwise Operation" column; defaults to PassThrough.
//
// When ONE_INSTANCE_PER_LIST is defined, only the first instance is part of the tuple.
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// First instance: always present, even when ONE_INSTANCE_PER_LIST is defined.
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
// "2x" instances
// The only "2x" instance is disabled; its trailing note says an assert is broken.
// DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 128, 32, 32, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, // Assert broken
// "part 2" instances
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// AGPR spill when using the permuted LDS layout, so use padding for these two.
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,80 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short aliases for element data types. F8 and BF8 exist only when the matching
// CK_ENABLE_FP8 / CK_ENABLE_BF8 macro is defined.
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
// S<...> is shorthand for ck::Sequence<...>; it is used for the compile-time integer lists
// in the instance table below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
// Brings the convolution layout tags into this namespace so they can be named unqualified.
using namespace ck::tensor_layout::convolution;
// Element-wise operation aliases.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
// Named convolution-forward specializations passed as ConvSpec by users of the instance list.
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
// GEMM specialization used by every instance in the table below.
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Pipeline scheduler shorthands, intended as values for the BlkGemmPipeSched parameter below.
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// Tuple of int8 grouped-convolution-forward device instances (the "mem" list, per the alias
// name). Every entry is a DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 specialization with
// I8 for the A/B/CShuffle/E data types, an int32_t accumulator, PassThrough A and B element-wise
// operations, GemmMNKPadding and BlockGemmPipelineVersion::v1. Block sizes are 128 or 64.
//
// Template parameters:
//   NDimSpatial      - forwarded as the first argument of every instance.
//   ALayout, BLayout, DsLayout, ELayout - tensor layouts forwarded to every instance.
//   ConvSpec         - convolution forward specialization shared by every instance.
//   BlkGemmPipeSched - pipeline scheduler shared by every instance.
//   DsDataTypes      - forwarded into the "Ds DataType" column; defaults to an empty Tuple<>.
//   OutElementOp     - forwarded into the "CDE Elementwise Operation" column; defaults to
//                      PassThrough.
//
// When ONE_INSTANCE_PER_LIST is defined, only the first instance is part of the tuple.
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// First instance: always present, even when ONE_INSTANCE_PER_LIST is defined.
// NOTE(review): it matches the last instance of this list except for the final
// CBlockTransfer ScalarPerVector value (2 here, 4 there) - confirm that both are wanted.
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, int32_t, I8, DsDataTypes, I8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -29,7 +29,6 @@
#include "grouped_convolution_forward_wmma.inc"
#endif
#include "grouped_convolution_forward_wmma_cshufflev3.inc"
#include "grouped_convolution_forward_comp_wmma_cshufflev3.inc"
#include "grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc"
#include "grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc"
#endif
@@ -832,6 +831,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instances(
op_ptrs);
}
#endif
}
@@ -872,12 +880,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
// add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_int8_instances(
// op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances(
op_ptrs);
}
#endif
}
@@ -1079,6 +1081,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA

View File

@@ -86,6 +86,20 @@ void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_i
PassThrough,
Clamp>>>& instances);
// Declaration of the instance-registration function for 3D grouped convolution forward with a
// Clamp output element-wise operation, F16 data and the NDHWGC/GKZYXC/NDHWGK layouts.
// `instances` is taken by non-const reference.
// NOTE(review): presumably the definition appends the matching device instances to
// `instances`; the definition is not visible here - confirm.
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,

View File

@@ -1,32 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
// The declaration below exists only when CK_ENABLE_INT8 is defined.
#ifdef CK_ENABLE_INT8
// Declaration of the registration function for the int8 "comp" instance list.
// `instances` is a vector of owning pointers to DeviceGroupedConvFwdMultipleABD operations
// (first template argument 2, int8_t data types, no D tensors, PassThrough element-wise
// operations) and is taken by non-const reference.
// NOTE(review): presumably the definition appends instances to the vector - confirm against
// the matching .cpp file.
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -90,6 +90,22 @@ void add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_f16_instanc
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
// Declaration of the instance-registration function for 2D grouped convolution forward with
// int8_t data, the GNHWC/GKYXC/GNHWK layouts, no D tensors and PassThrough element-wise
// operations. `instances` is taken by non-const reference.
// NOTE(review): presumably the definition appends the matching device instances to
// `instances` - confirm against the matching .cpp file.
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
@@ -336,6 +352,22 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f8_insta
F8>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
// Declaration of the instance-registration function for 3D grouped convolution forward with
// int8_t data, the NGCDHW/GKCZYX/NGKDHW layouts, no D tensors and PassThrough element-wise
// operations. `instances` is taken by non-const reference.
// NOTE(review): presumably the definition appends the matching device instances to
// `instances` - confirm against the matching .cpp file.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NGCDHW,
GKCZYX,
Empty_Tuple,
NGKDHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,

View File

@@ -114,6 +114,7 @@ set(GROUPED_CONV2D_FWD
# GNHWC, GKYXC, GNHWK
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_f16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instance.cpp
# NHWGC, GKYXC, NHWGK
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
@@ -122,9 +123,6 @@ set(GROUPED_CONV2D_FWD
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_ngchw_gkyxc_ngkhw_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_ngchw_gkyxc_ngkhw_f16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_ngchw_gkyxc_ngkhw_int8_instance.cpp
#comp
wmma/comp/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
)
# Add generated files for sharded instantiations.
include(ShardInstantiation)
@@ -178,14 +176,6 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
# Root directory used by the OUTPUT_DIR argument below.
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
# Sharded instantiation of the WMMA int8 "mem_inter" conv2d instance list: 16 shards driven by
# the .in template, with output under ${GENERATED_DIR}/wmma/mem.
# NOTE(review): presumably the generated sources are appended to the GROUPED_CONV2D_FWD list
# named by SRC_LIST - confirm in the ShardInstantiation module.
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
TEMPLATE_FILE wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_CONV2D_FWD
OUTPUT_DIR ${GENERATED_DIR}/wmma/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in
@@ -195,8 +185,8 @@ generate_sharded_instantiations(
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
TEMPLATE_FILE wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in
INSTANCES_NAME device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances
TEMPLATE_FILE wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_CONV2D_FWD
OUTPUT_DIR ${GENERATED_DIR}/wmma/mem

View File

@@ -2,19 +2,18 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_comp_instances(
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GNHWC,
GKYXC,
Empty_Tuple,
NHWGK,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
@@ -25,39 +24,36 @@ void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_comp_i
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
ConvFwdOddC>{});
}
} // namespace instance

View File

@@ -2,12 +2,12 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/utility/filter_tuple.hpp"
namespace ck::tensor_operation::device::instance {
using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances =
using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances =
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,58 +22,54 @@ using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter
PassThrough>>>;
template <int Shards, int ShardIndex>
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances_shard(
device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances& instances)
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances_shard(
device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances& instances)
{
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault,
Interwave>,
ConvFwdDefault>,
Shards,
ShardIndex>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0,
Interwave>,
ConvFwd1x1P0>,
Shards,
ShardIndex>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0,
Interwave>,
ConvFwd1x1S1P0>,
Shards,
ShardIndex>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC,
Interwave>,
ConvFwdOddC>,
Shards,
ShardIndex>{});
}

View File

@@ -74,6 +74,7 @@ set(GROUPED_CONV3D_FWD
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp
)
# Add generated files for sharded instantiations.
include(ShardInstantiation)
@@ -185,37 +186,6 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/wmma
)
# Four sharded instantiations for the WMMA f16 conv3d instance lists. Each call uses 16 shards,
# names GROUPED_CONV3D_FWD as SRC_LIST and writes under ${GENERATED_DIR}/wmma/{mem,comp}.
# NOTE(review): presumably the generated sources are appended to GROUPED_CONV3D_FWD - confirm
# in the ShardInstantiation module.
# NGCDHW/GKCZYX/NGKDHW, "mem_inter" list.
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances
TEMPLATE_FILE wmma/mem/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_CONV3D_FWD
OUTPUT_DIR ${GENERATED_DIR}/wmma/mem
)
# NGCDHW/GKCZYX/NGKDHW, "mem_intra" list.
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances
TEMPLATE_FILE wmma/mem/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_CONV3D_FWD
OUTPUT_DIR ${GENERATED_DIR}/wmma/mem
)
# NDHWGC/GKZYXC/NDHWGK, "comp" list.
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
TEMPLATE_FILE wmma/comp/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_CONV3D_FWD
OUTPUT_DIR ${GENERATED_DIR}/wmma/comp
)
# NGCDHW/GKCZYX/NGKDHW, "comp" list.
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_comp_instances
TEMPLATE_FILE wmma/comp/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_CONV3D_FWD
OUTPUT_DIR ${GENERATED_DIR}/wmma/comp
)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp)

View File

@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_gndhwc_gkzyxc_gndhwk_int8_ins
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
ConvFwdDefault>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
ConvFwd1x1P0>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
ConvFwd1x1S1P0>{});
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
ConvFwd1x1S1P0>{});
}
} // namespace instance

View File

@@ -1,81 +1,55 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
#include "ck/utility/filter_tuple.hpp"
namespace ck::tensor_operation::device::instance {
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances =
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NGCDHW,
GKCZYX,
Empty_Tuple,
NHWGK,
NGKDHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>;
template <int Shards, int ShardIndex>
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances_shard(
device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances& instances)
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
NHWGC,
GKYXC,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3,
NGCDHW,
GKCZYX,
Empty_Tuple,
NHWGK,
ConvFwdDefault,
Intrawave>,
Shards,
ShardIndex>{});
NGKDHW,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
NHWGC,
GKYXC,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3,
NGCDHW,
GKCZYX,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0,
Intrawave>,
Shards,
ShardIndex>{});
NGKDHW,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
NHWGC,
GKYXC,
device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3,
NGCDHW,
GKCZYX,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0,
Intrawave>,
Shards,
ShardIndex>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<
device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC,
Intrawave>,
Shards,
ShardIndex>{});
NGKDHW,
ConvFwd1x1S1P0>{});
}
} // namespace ck::tensor_operation::device::instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck