Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)

This reverts commit 1a97bde100db0b7b5def711082bd2ea0e0aafc03.

[ROCm/composable_kernel commit: 03b59f8c76]
Illia Silin authored 2025-09-15 08:27:04 -07:00, committed by GitHub
parent 5c712f856f
commit 8cbf571d53
44 changed files with 175 additions and 1085 deletions

View File

@@ -87,9 +87,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32)
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)

View File

@@ -310,14 +310,10 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
return true;
}
template <typename DataType, typename ComputeDataType = DataType>
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
@@ -355,14 +351,10 @@ inline __host__ __device__ constexpr double get_rtol()
}
}
template <typename DataType, typename ComputeDataType = DataType>
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
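
A note on the signature being reverted here: the second template parameter defaulted to the first, so adding it in #2763 kept every existing single-argument call site compiling unchanged, and only the TF32 example opted in explicitly. Illustrative call shapes against the removed two-parameter version:

    get_rtol<float>()              // default ComputeDataType = float; pre-existing callers unchanged
    get_rtol<float, ck::tf32_t>()  // the TF32 branch deleted above (both return 1e-3 in this example)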

View File

@@ -1,85 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "common.hpp"
#define USING_DIRECT_LOADS 1
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif
#define EXAMPLE_WITH_COMPUTE_DATATYPE
using F32 = float;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F32;
using ComputeDataType = ck::tf32_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#if USING_DIRECT_LOADS
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer|
// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds|
// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type |
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block|
// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32,
8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1,
1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>;
// clang-format on
#else
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType,
ComputeDataType>;
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
#undef EXAMPLE_WITH_COMPUTE_DATATYPE

View File

@@ -4,11 +4,6 @@
#pragma once
#include "ck/library/utility/validation_common.hpp"
// use macro to minimize code change
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif
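// The guard removed above is a small opt-in hook: an example that wants a distinct
// compute type defines EXAMPLE_WITH_COMPUTE_DATATYPE and its own ComputeDataType alias
// before including this file; every other example falls back to ComputeDataType = AccDataType.
// Sketch of how the deleted TF32 GEMM example used it (as shown earlier in this diff):
//   #define EXAMPLE_WITH_COMPUTE_DATATYPE
//   using ComputeDataType = ck::tf32_t;
//   #include "run_gemm_example.inc"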
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
@@ -223,8 +218,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType, ComputeDataType>(),
get_atol<CDataType, ComputeDataType>());
get_rtol<CDataType>(),
get_atol<CDataType>());
#endif
}
@@ -254,8 +249,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_device_ref_result,
"Error: Incorrect results!",
get_rtol<CDataType, ComputeDataType>(),
get_atol<CDataType, ComputeDataType>());
get_rtol<CDataType>(),
get_atol<CDataType>());
}
return pass == true;

View File

@@ -1,5 +1,4 @@
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
@@ -20,4 +19,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
set(target 1)
endif()
endforeach()
endforeach()

View File

@@ -27,14 +27,10 @@ void print_helper_msg()
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <typename DataType, typename GemmType = DataType>
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 5e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
@@ -72,14 +68,10 @@ inline __host__ __device__ constexpr double get_rtol()
}
}
template <typename DataType, typename GemmType = DataType>
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 1e-2;
}
else if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
@@ -124,8 +116,7 @@ template <ck::index_t NDimSpatial,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance,
typename ComputeDataType = OutDataType>
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
@@ -237,11 +228,7 @@ bool run_grouped_conv_fwd(bool do_verification,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
0,
0,
0,
ComputeDataType>();
OutElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
@@ -262,8 +249,8 @@ bool run_grouped_conv_fwd(bool do_verification,
return ck::utils::check_err(out_device,
out_host,
"Error: incorrect results!",
get_rtol<OutDataType, ComputeDataType>(),
get_atol<OutDataType, ComputeDataType>());
get_rtol<OutDataType>(),
get_atol<OutDataType>());
}
return true;
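
For context on the different bounds in the branches removed above: tf32 (see the tf32_t = _BitInt(19) definition deleted later in this diff) keeps only 10 mantissa bits, so its relative spacing is far coarser than fp32's:

    2^-10 = 1/1024    ≈ 9.8e-4   (tf32, 10 explicit mantissa bits)
    2^-23 = 1/8388608 ≈ 1.2e-7   (fp32, 23 explicit mantissa bits)

which is why the deleted TF32 branches used the looser rtol = 5e-3 / atol = 1e-2 while the plain fp32 path keeps 1e-3.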

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#define EXAMPLE_WITH_COMPUTE_DATATYPE
using InDataType = float;
using WeiDataType = float;
using AccDataType = float;
using CShuffleDataType = float;
using OutDataType = float;
using ComputeDataType = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout, // ALayout
WeiLayout, // BLayout
ck::Tuple<>, // DsLayout
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ck::Tuple<>, // DsDataType
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
192, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
3, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CDEBlockTransferScalarPerVector_NPerBlock
ComputeDataType, // AComputeDataType
ComputeDataType, // BComputeDataType
ck::LoopScheduler::Default, // LoopScheduler
1 // NumGroupsToMerge
>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
#undef EXAMPLE_WITH_COMPUTE_DATATYPE

View File

@@ -7,8 +7,6 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#define EXAMPLE_WITH_COMPUTE_DATATYPE
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
@@ -89,5 +87,3 @@ int main(int argc, char* argv[])
}
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}
#undef EXAMPLE_WITH_COMPUTE_DATATYPE

View File

@@ -3,11 +3,6 @@
#pragma once
// use macro to minimize code change
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
@@ -70,17 +65,17 @@ bool run_convnd_fwd_example(int argc, char* argv[])
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>,
ComputeDataType>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
};
namespace ctc = ck::tensor_layout::convolution;

View File

@@ -134,7 +134,5 @@ inline bool is_wmma_supported()
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
}
inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
} // namespace ck
#endif

View File

@@ -180,13 +180,13 @@ check_err(const Range& out,
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
err_count++;
}
}
if(!res)

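A minimal standalone sketch of the elementwise criterion in this hunk (a hypothetical free-standing helper, not the library's check_err): a value passes when both it and the reference are finite and the absolute error stays within atol + rtol * |ref|.

    #include <cmath>

    bool approx_equal(double out, double ref, double rtol, double atol)
    {
        const double err = std::abs(out - ref);
        return std::isfinite(out) && std::isfinite(ref) && err <= atol + rtol * std::abs(ref);
    }

    // With the fp32 tolerances from the examples earlier in this diff (rtol = atol = 1e-3):
    //   approx_equal(1.0005, 1.0, 1e-3, 1e-3) -> true   (err 5e-4 <= 2e-3)
    //   approx_equal(1.005,  1.0, 1e-3, 1e-3) -> false  (err 5e-3 >  2e-3)
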
View File

@@ -49,11 +49,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using ElementDataTypeA =
conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
using ElementDataTypeB =
conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
@@ -69,7 +64,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, false, false>{};
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -177,11 +172,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"ComputeTypeA and ComputeTypeB must be same when one of them is tf32");
}
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
@@ -307,9 +297,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeA>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeB>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -331,20 +321,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<ElementDataTypeA, KPack> a_thread_vec;
vector_type<ElementDataTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ElementDataTypeA>()(i) = a_thread_buf
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<ElementDataTypeB>()(i) = b_thread_buf
b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type_a =
typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -371,7 +361,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ElementDataTypeA,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
@@ -381,7 +371,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
ElementDataTypeB,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
@@ -455,11 +445,6 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using Base::KPerThread;
using Base::xdlops_gemm;
using ElementDataTypeA =
conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
using ElementDataTypeB =
conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
// 2-wave optimized blockwise gemm
@@ -468,9 +453,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeA>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeB>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
@@ -514,22 +499,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ElementDataTypeA, KPack> a_thread_vec;
vector_type<ElementDataTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ElementDataTypeA>()(i) =
a_thread_vec.template AsType<ComputeTypeA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<ElementDataTypeB>()(i) =
b_thread_vec.template AsType<ComputeTypeB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}];
});
using mfma_input_type_a =
typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -578,7 +563,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ElementDataTypeA,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -588,7 +573,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
ElementDataTypeB,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -637,21 +622,19 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
BlockSize,
FloatA,
FloatB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack,
ComputeTypeA,
ComputeTypeB,
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{};
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatA,
FloatB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack,
ComputeTypeA,
ComputeTypeB>{};
}
};

View File

@@ -119,9 +119,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer,
ComputeDataType>;
PipelineVer>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -216,14 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
return false;
}
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;

View File

@@ -1003,20 +1003,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
void Print() const
{
std::cout << "AComputeDataType: " << get_type_name<AComputeDataType>()
<< "; BComputeDataType: " << get_type_name<BComputeDataType>()
<< "; EDataType: " << get_type_name<EDataType>() << std::endl;
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_
<< std::endl;
}
// private:
@@ -1207,6 +1198,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
@@ -1289,6 +1281,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(NeedTransposeKernel)
{
const index_t a_grid_size =
@@ -1693,23 +1686,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
// check Gridwise GEMM
if(get_warp_size() == 64)
{
@@ -1789,28 +1766,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!(ck::get_device_name() == "gfx942"))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "TF32 is enabled on gfx942 only" << std::endl;
}
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
return false;
}

View File

@@ -708,9 +708,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType,
BComputeDataType>();
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -107,10 +107,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -661,27 +659,26 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType_,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType_,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType_,
BComputeDataType_>();
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -144,7 +144,7 @@ template <typename ALayout,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v4,
typename BComputeDataType_ = AComputeDataType_>
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -172,10 +172,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
using AComputeDataType = AComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -576,6 +573,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
@@ -642,10 +640,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType_,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType_,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
@@ -661,9 +659,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType_,
BComputeDataType_>();
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -41,11 +41,11 @@ static constexpr bool scale_mfma_hw_support()
enum struct MfmaInstr
{
mfma_f32_32x32x1f32 = 0,
mfma_f32_16x16x1f32,
mfma_f32_4x4x1f32,
mfma_f32_32x32x2f32,
mfma_f32_16x16x4f32,
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32,
mfma_f32_16x16x4xf32,
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
@@ -78,8 +78,6 @@ enum struct MfmaInstr
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4,
mfma_f32_16x16x8xf32, // tf32
mfma_f32_32x32x4xf32,
// gfx11
wmma_f32_16x16x16_f16,
wmma_f32_16x16x16_bf16,
@@ -100,7 +98,7 @@ template <MfmaInstr instr>
struct mfma_type;
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x1f32>
struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
@@ -122,7 +120,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1f32>
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x2f32>
struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
@@ -144,7 +142,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2f32>
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x4f32>
struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
@@ -166,7 +164,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4f32>
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x1f32>
struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
@@ -189,7 +187,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1f32>
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_type<MfmaInstr::mfma_f32_4x4x1f32>
struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
@@ -949,70 +947,6 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
}
};
/**
* num_threads_per_blk == n_per_blk
* num_regs_per_blk * num_input_blks == m_per_blk
* num_regs_per_blk * wave_size == m_per_blk * n_per_blk
*
* group_size * num_groups_per_blk == num_regs_per_blk
*
* num_regs_per_blk is output(CD) register size which is determined by the instruction.
* k_per_blk(K1PerXdlops) is input(AB) register size which is determined by the instruction.
* group_size is corresponding to CD rows mapping. see: GetBeginOfThreadBlk()
*
* is_k_reduction = (k_per_blk == KPerXdlops) ? false: true.
*
* if (is_k_reduction){
* num_output_blks == 1;
* } else {
* num_input_blks == num_output_blks;
* }
*/
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x8xf32>
{
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t num_threads_per_blk = n_per_blk; // 16
static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4
static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2.
static constexpr bool is_k_reduction = true;
// AB register size : 2, register size: 4
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x8xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
{
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t num_threads_per_blk = n_per_blk; // 32
static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16
static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2
static constexpr index_t group_size = 4; // corresponding to CD rows mapping
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = true;
// AB register size: 2, CD register size: 16
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x4xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
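// Worked check (illustrative only) that the two tf32 entries removed above satisfy the
// layout identities documented in the comment block preceding them:
//   mfma_f32_16x16x8xf32 : num_regs_per_blk = 16*16/64 = 4,  num_input_blks = 16/4  = 4,
//                          group_size * num_groups_per_blk = 4*1 = 4  = num_regs_per_blk
//   mfma_f32_32x32x4xf32 : num_regs_per_blk = 32*32/64 = 16, num_input_blks = 32/16 = 2,
//                          group_size * num_groups_per_blk = 4*4 = 16 = num_regs_per_blk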
// gfx11
struct mfma_type_gfx11_base
{
@@ -1182,20 +1116,6 @@ struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12
}
};
/**
* @class MfmaSelector
* @brief Selects the appropriate MFMA instruction type and configuration for given data types
* and tile sizes on AMD GPUs.
*
* @tparam base_type The base data type for the matrix operation (e.g., float, half_t).
* @tparam MPerXdlops The number of rows per XDLops tile.
* @tparam NPerXdlops The number of columns per XDLops tile.
* @tparam additional_type (Optional) Additional data type for mixed-precision or special cases.
* Defaults to base_type.
* @tparam is_single_rate_mfma (Optional) Whether to use single-rate MFMA instructions.
* Defaults to false.
* @tparam is_scale_mfma (Optional) Whether to use scale MFMA instructions. Defaults to false.
*/
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
@@ -1227,37 +1147,37 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<float, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x1f32;
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
constexpr auto GetMfma<float, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x1f32;
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
constexpr auto GetMfma<float, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x1f32;
return MfmaInstr::mfma_f32_16x16x1xf32;
}
template <>
constexpr auto GetMfma<float, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x1f32;
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
constexpr auto GetMfma<float, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x1f32;
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
constexpr auto GetMfma<float, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x2f32;
return MfmaInstr::mfma_f32_32x32x2xf32;
}
template <>
@@ -1268,22 +1188,10 @@ struct MfmaSelector
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f32_16x16x4f32;
return MfmaInstr::mfma_f32_16x16x4xf32;
#endif
}
template <>
constexpr auto GetMfma<tf32_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x4xf32;
}
template <>
constexpr auto GetMfma<tf32_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x8xf32;
}
template <>
constexpr auto GetMfma<half_t, 64, 64>()
{
@@ -1988,7 +1896,7 @@ struct XdlopsGemm
__device__ __host__ static constexpr index_t GetRegSizePerXdlops()
{
return mfma_instr.num_regs_per_blk;
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
}
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
@@ -1998,12 +1906,12 @@ struct XdlopsGemm
{
static_assert(
is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, tf32_t>::value || is_same<base_type, half_t>::value ||
is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value ||
is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
is_same<base_type, bf8_t>::value ||
(is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
"base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)

View File

@@ -1636,45 +1636,4 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
}
};
/******************* tf32 *************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8xf32;
template <>
struct intrin_mfma_f32_16x16x8xf32<16, 16>
{
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4xf32;
template <>
struct intrin_mfma_f32_32x32x4xf32<32, 32>
{
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
} // namespace ck

View File

@@ -26,7 +26,6 @@ using byte = unsigned char;
using std::byte;
#endif
using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
@@ -462,38 +461,4 @@ using int64_t = long long;
using int64_t = long;
#endif
template <typename T>
inline const char* get_type_name()
{
if constexpr(is_same_v<T, half_t>)
return "fp16";
else if constexpr(is_same_v<T, bhalf_t>)
return "bf16";
else if constexpr(is_same_v<T, tf32_t>)
return "tf32";
else if constexpr(is_same_v<T, int4_t>)
return "int4";
else if constexpr(is_same_v<T, f4_t>)
return "f4";
else if constexpr(is_same_v<T, f6_t>)
return "f6";
else if constexpr(is_same_v<T, bf6_t>)
return "bf6";
else if constexpr(is_same_v<T, f8_t>)
return "f8";
else if constexpr(is_same_v<T, bf8_t>)
return "bf8";
else if constexpr(is_same_v<T, e8m0_bexp_t>)
return "e8m0";
else if constexpr(is_same_v<T, float>)
return "fp32";
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
else
return "unknown";
#else
else
return typeid(T).name();
#endif
}
} // namespace ck

View File

@@ -187,19 +187,6 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}
template <typename Y, enable_if_t<is_same_v<Y, ck::tf32_t>, bool> = false>
inline __host__ __device__ constexpr float type_convert(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
u.int32 = u.int32 & 0xffffe000;
return u.fp32;
}
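// Illustration only, as a hypothetical standalone host helper mirroring the specialization
// removed above: the 0xffffe000 mask keeps the top 19 bits of the IEEE-754 encoding
// (1 sign, 8 exponent, 10 mantissa bits, matching the tf32_t = _BitInt(19) definition
// deleted earlier in this diff) and clears the remaining 13 mantissa bits, i.e. it
// truncates an fp32 value to TF32 precision. The reference GEMM/conv changes further down
// remove exactly this truncation of both operands before the fp32 multiply-accumulate.
#include <cstdint>
#include <cstring>
inline float truncate_to_tf32(float x) // hypothetical name, not a CK function
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits)); // memcpy instead of the union type-pun
    bits &= 0xffffe000u;                  // clears the 13 low mantissa bits (32 - 19)
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}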
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)

View File

@@ -59,7 +59,6 @@ template <ck::index_t NDimSpatial,
ck::index_t NumAElementwiseTensor = 0,
ck::index_t NumBElementwiseTensor = 0,
ck::index_t NumDElementwiseTensor = 0,
typename ComputeDataType = InDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
@@ -328,10 +327,8 @@ struct ReferenceConvFwd : public device::BaseOperator
z,
y,
x);
v_acc += ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_wei));
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
}
}
}

View File

@@ -25,12 +25,6 @@ template <typename ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator
{
using ElementDataTypeA =
ck::conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
using ElementDataTypeB =
ck::conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
// Argument
struct Argument : public device::BaseArgument
{
@@ -69,8 +63,8 @@ struct ReferenceGemm : public device::BaseOperator
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc{0};
ElementDataTypeA v_a{0};
ElementDataTypeB v_b{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
for(int k = 0; k < K; ++k)
{
@@ -83,16 +77,16 @@ struct ReferenceGemm : public device::BaseOperator
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_a = type_convert<ElementDataTypeA>(i4);
v_a = type_convert<ComputeTypeA>(i4);
}
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
v_a = type_convert<ElementDataTypeA>(
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
else
v_a = type_convert<ElementDataTypeA>(
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
@@ -100,7 +94,7 @@ struct ReferenceGemm : public device::BaseOperator
is_same_v<ADataType, f6x32_pk_t> ||
is_same_v<ADataType, bf6x32_pk_t>)
{
v_a = type_convert<ElementDataTypeA>(
v_a = type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size));
}
else
@@ -117,16 +111,16 @@ struct ReferenceGemm : public device::BaseOperator
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_b = type_convert<ElementDataTypeB>(i4);
v_b = type_convert<ComputeTypeB>(i4);
}
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
v_b = type_convert<ElementDataTypeB>(
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
else
v_b = type_convert<ElementDataTypeB>(
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
@@ -134,7 +128,7 @@ struct ReferenceGemm : public device::BaseOperator
is_same_v<BDataType, f6x32_pk_t> ||
is_same_v<BDataType, bf6x32_pk_t>)
{
v_b = type_convert<ElementDataTypeB>(
v_b = type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size));
}
else
@@ -142,18 +136,8 @@ struct ReferenceGemm : public device::BaseOperator
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
is_same_v<ComputeTypeA, ck::tf32_t>)
{ // only for tf32 now
v_acc +=
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
}
else
{
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c{0};

View File

@@ -38,10 +38,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CDEElementwiseOperation c_element_op)
{
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ElementDataTypeA =
ck::conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
using ElementDataTypeB =
ck::conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
const int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
@@ -50,8 +46,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
{
AccDataType v_acc{0};
ElementDataTypeA v_a{0};
ElementDataTypeB v_b{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
CDataType v_c{0};
for(int k_idx = 0; k_idx < k; ++k_idx)
@@ -80,16 +76,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// apply b_element_op
b_element_op(v_b, p_b_grid[element_idx_b]);
// multiply and accumulate
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
is_same_v<ComputeTypeA, ck::tf32_t>)
{ // only for tf32 now
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
}
else
{
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
}
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
}
// apply c_element_op
c_element_op(v_c, v_acc);

View File

@@ -16,7 +16,6 @@ namespace instance {
// aliasing, for commonly used data type
using F64 = double;
using F32 = float;
using TF32 = ck::tf32_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;

View File

@@ -16,7 +16,6 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

View File

@@ -24,7 +24,6 @@ using BF8 = ck::bf8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -200,7 +199,7 @@ using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
// 32x32 instance
// 32x32 instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
@@ -285,45 +284,7 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 192, 16, 4, 4, 32, 32, 2, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_f32_tf32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 192, 16, 4, 4, 32, 32, 2, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -443,12 +443,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP8
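The hunk above removes the factory's TF32 registration branch: the instance factory picks which instance lists to register by comparing its template parameters at compile time. Below is a minimal, standard-library-only sketch of that if constexpr / is_same_v dispatch pattern; tf32_t, get_instances and the printed strings are illustrative placeholders, not CK's actual API.

#include <iostream>
#include <type_traits>

struct tf32_t
{
}; // illustrative stand-in for ck::tf32_t

template <typename InDataType, typename ComputeType>
void get_instances()
{
    if constexpr(std::is_same_v<InDataType, float> && std::is_same_v<ComputeType, float>)
    {
        std::cout << "register fp32 instances\n";
    }
    else if constexpr(std::is_same_v<InDataType, float> && std::is_same_v<ComputeType, tf32_t>)
    {
        // this corresponds to the branch removed by the revert: fp32 data computed in tf32
        std::cout << "register fp32 instances computed in tf32\n";
    }
}

int main()
{
    get_instances<float, float>();  // fp32 path, still registered
    get_instances<float, tf32_t>(); // tf32 path, no longer registered after the revert
    return 0;
}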

View File

@@ -215,14 +215,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_XDL

View File

@@ -578,22 +578,6 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_insta
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,

View File

@@ -210,14 +210,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_XDL

View File

@@ -578,22 +578,6 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,

View File

@@ -132,7 +132,6 @@ void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_insta
PassThrough,
PassThrough,
DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -160,8 +159,7 @@ template <ck::index_t NumDimSpatial,
typename WeiDataType,
typename DDataTypes,
typename OutDataType,
typename AComputeType,
typename BComputeType = AComputeType>
typename ComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
@@ -175,8 +173,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::DynamicUnaryOp,
AComputeType,
BComputeType>>
ComputeType>>
{
using DeviceOp =
DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
@@ -191,8 +188,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::DynamicUnaryOp,
AComputeType,
BComputeType>;
ComputeType>;
static auto GetInstances()
{
@@ -211,7 +207,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
@@ -248,7 +244,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
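The hunks above also revert the factory's template interface from a split AComputeType/BComputeType pair (with B defaulting to A) back to a single ComputeType. The sketch below, with purely illustrative struct names, contrasts the defaulted two-parameter shape being removed with the single-parameter shape being restored.

#include <type_traits>

// shape removed by the revert: two compute types, with B defaulting to A
template <typename DataType,
          typename AComputeType = DataType,
          typename BComputeType = AComputeType>
struct OpWithSplitComputeTypes
{
    static constexpr bool same_compute = std::is_same_v<AComputeType, BComputeType>;
};

// shape restored by the revert: a single compute type
template <typename DataType, typename ComputeType = DataType>
struct OpWithSingleComputeType
{
};

static_assert(OpWithSplitComputeTypes<float>::same_compute,
              "with no explicit arguments both compute types fall back to the data type");

int main() { return 0; }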

View File

@@ -559,22 +559,6 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,

View File

@@ -7,7 +7,6 @@ set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp
@@ -35,7 +34,7 @@ set(GROUPED_CONV3D_FWD
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp

View File

@@ -1,56 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
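The deleted translation unit above registered the TF32 instance lists by handing compile-time tuples of instance types to add_device_operation_instances. A minimal, self-contained sketch of that registration shape follows; BaseOp, InstanceA and InstanceB are hypothetical stand-ins for CK's device-operation types, so this is only an analogue of the pattern, not the library code.

#include <memory>
#include <string>
#include <tuple>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp()                         = default;
    virtual std::string GetTypeString() const = 0;
};

struct InstanceA : BaseOp
{
    std::string GetTypeString() const override { return "InstanceA"; }
};

struct InstanceB : BaseOp
{
    std::string GetTypeString() const override { return "InstanceB"; }
};

template <typename... Instances>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& out,
                                           std::tuple<Instances...>)
{
    // one push_back per instance type in the compile-time tuple
    (out.push_back(std::make_unique<Instances>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> instances;
    add_device_operation_instances_sketch(instances, std::tuple<InstanceA, InstanceB>{});
    return instances.size() == 2 ? 0 : 1;
}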

View File

@@ -2,7 +2,7 @@
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
@@ -29,16 +29,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
@@ -47,7 +38,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
@@ -56,7 +47,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
@@ -67,7 +58,7 @@ generate_sharded_instantiations(
)
# large tensor
# NDHWGC, GKZYXC, NDHWGK
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
@@ -76,7 +67,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
@@ -85,7 +76,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
@@ -96,7 +87,7 @@ generate_sharded_instantiations(
)
# merged groups
# NDHWGC, GKZYXC, NDHWGK
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
@@ -105,7 +96,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
@@ -114,7 +105,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
@@ -125,7 +116,7 @@ generate_sharded_instantiations(
)
#mem
# NDHWGC, GKZYXC, NDHWGK
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
@@ -134,7 +125,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
@@ -143,7 +134,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
@@ -153,7 +144,7 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
# NDHWGC, GKZYXC, NDHWGK
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
@@ -162,7 +153,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
@@ -171,7 +162,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
@@ -182,7 +173,7 @@ generate_sharded_instantiations(
)
#comp
# NDHWGC, GKZYXC, NDHWGK
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
@@ -191,7 +182,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
@@ -200,7 +191,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
@@ -209,7 +200,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
@@ -218,7 +209,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
@@ -227,7 +218,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
@@ -236,7 +227,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances

View File

@@ -1,81 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include "ck/utility/filter_tuple.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances =
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32, F32, F32, F32, F32>,
F32,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp,
TF32,
TF32>>>;
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
template <int Shards, int ShardIndex>
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard(
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& instances)
{
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F32, F32, F32, F32, F32>,
BiasNormalizeInInferClamp>,
Shards,
ShardIndex>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<F32, F32, F32, F32, F32>,
BiasNormalizeInInferClamp>,
Shards,
ShardIndex>{});
add_device_operation_instances(
instances,
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<F32, F32, F32, F32, F32>,
BiasNormalizeInInferClamp>,
Shards,
ShardIndex>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
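The deleted .in template above instantiated only a slice of the TF32 configuration list per generated shard, combining filter_tuple_by_modulo_t with the Shards/ShardIndex parameters driven by generate_sharded_instantiations in the CMake file. The standard-library sketch below illustrates the same modulo-based slicing idea with an index_sequence; the names and printed values are illustrative only, not CK's utilities.

#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

template <std::size_t Shards, std::size_t ShardIndex, typename Tuple, std::size_t... Is>
void for_each_in_shard_impl(const Tuple& configs, std::index_sequence<Is...>)
{
    // visit only the elements whose index falls into this shard
    ((Is % Shards == ShardIndex ? (std::cout << std::get<Is>(configs) << '\n', 0) : 0), ...);
}

template <std::size_t Shards, std::size_t ShardIndex, typename... Ts>
void for_each_in_shard(const std::tuple<Ts...>& configs)
{
    for_each_in_shard_impl<Shards, ShardIndex>(configs, std::index_sequence_for<Ts...>{});
}

int main()
{
    auto configs = std::make_tuple(128, 256, 64, 32); // stand-ins for kernel configurations
    for_each_in_shard<2, 0>(configs); // shard 0 instantiates configs 0 and 2
    for_each_in_shard<2, 1>(configs); // shard 1 instantiates configs 1 and 3
    return 0;
}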

View File

@@ -23,8 +23,6 @@ set(GROUPED_CONV3D_FWD
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
)
)
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})

View File

@@ -1,60 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F32>,
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<F32>,
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<F32>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -23,8 +23,6 @@ set(GROUPED_CONV3D_FWD
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
)
)
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})

View File

@@ -1,60 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
os << "strides {";
LogRange(os, desc.GetStrides(), ", ");
os << "} ";
os << "}";
return os;
}

View File

@@ -21,15 +21,14 @@ enum struct ConvLayout
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F8_F8, // 4
BF8_BF8_F8, // 5
F8_BF8_F8, // 6
BF8_F8_F8, // 7
F32_F32_F32_TF32, // 8
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F8_F8, // 4
BF8_BF8_F8, // 5
F8_BF8_F8, // 6
BF8_F8_F8, // 7
};
enum struct IndexType
@@ -54,7 +53,6 @@ static void print_helper_msg()
<< " 5: Input bf8, Weight bf8, Output fp8\n"
<< " 6: Input fp8, Weight bf8, Output fp8\n"
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
<< " 8: Input fp32, Weight fp32, Output fp32, Compute tf32\n"
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
@@ -105,7 +103,6 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using INT8 = int8_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using TF32 = ck::tf32_t;
//
using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -264,10 +261,6 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
// NHWGC_GKYXC_NHWGK
else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -374,10 +367,6 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
// NGCDHW_GKCZYX_NGKDHW
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -395,10 +384,6 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
return profile(
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
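The profiler hunks above drop the ConvDataType::F32_F32_F32_TF32 enumerator, its help text, and the corresponding dispatch branches. The profiler maps a runtime ConvDataType value onto a call to a templated profile function with concrete data types; the following heavily simplified sketch shows that dispatch shape with hypothetical names (float and unsigned short stand in for the real CK types).

#include <iostream>

enum struct ConvDataType
{
    F32_F32_F32, // 0
    F16_F16_F16, // 1
};

template <typename InDataType, typename WeiDataType, typename OutDataType>
int profile()
{
    std::cout << "profiling a problem with " << sizeof(InDataType) << "-byte inputs\n";
    return 0;
}

int run(ConvDataType data_type)
{
    if(data_type == ConvDataType::F32_F32_F32)
    {
        return profile<float, float, float>();
    }
    else if(data_type == ConvDataType::F16_F16_F16)
    {
        // unsigned short stands in for ck::half_t in this sketch
        return profile<unsigned short, unsigned short, unsigned short>();
    }
    std::cout << "this data_type is not implemented" << std::endl;
    return 1;
}

int main() { return run(ConvDataType::F32_F32_F32); }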