mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Refactor handling of GEMM-K batch template parameter in conv bwd weight factory.
This commit is contained in:
@@ -51,12 +51,24 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
concept HasGemmKBatch = requires(T t) {
|
||||
{ t.k_batch_size};
|
||||
};
|
||||
|
||||
// Concept to check if GEMM k batch size is specified.
|
||||
template <typename T>
|
||||
concept GemmKBatchSizeWellDefinedIfProvided =
|
||||
!HasGemmKBatch<T> || requires(T t) { {t.k_batch_size} -> std::convertible_to<size_t>; };
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<size_t>;
|
||||
{ t.m_n } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
GemmKBatchSizeWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
// Concept for thread cluster dimensions for GEMM output tensor.
|
||||
@@ -91,6 +103,8 @@ concept EpilogueDescriptor = requires(T t) {
|
||||
template <typename T>
|
||||
concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
} || requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
@@ -166,7 +180,6 @@ concept GridwiseFwdXdlGemmDescriptor = requires (T t){
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseBwdXdlGemmDescriptor = requires (T t){
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
@@ -181,6 +181,14 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string {
|
||||
msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n";
|
||||
}
|
||||
|
||||
// k_batch_size is optional - only report if it exists and has wrong type
|
||||
if constexpr (requires(BT t) { t.k_batch_size; }) {
|
||||
using KBatchType = decltype(std::declval<BT>().k_batch_size);
|
||||
constexpr bool convertible = std::convertible_to<KBatchType, size_t>;
|
||||
msg += std::string(" → ") + prefix + ".k_batch_size (optional): " + std::string(CHECK_MARK(convertible)) +
|
||||
std::string(get_type_info<KBatchType>()) + "\n";
|
||||
}
|
||||
|
||||
return msg;
|
||||
}
|
||||
|
||||
@@ -288,7 +296,9 @@ consteval auto diagnose_access_order(const char* prefix) -> std::string {
|
||||
|
||||
if constexpr (requires(AO t) { t.order; }) {
|
||||
using OrderType = decltype(std::declval<AO>().order);
|
||||
constexpr bool convertible = std::convertible_to<OrderType, std::array<size_t, 3>>;
|
||||
constexpr bool convertible_3 = std::convertible_to<OrderType, std::array<size_t, 3>>;
|
||||
constexpr bool convertible_4 = std::convertible_to<OrderType, std::array<size_t, 4>>;
|
||||
constexpr bool convertible = convertible_3 || convertible_4;
|
||||
msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) +
|
||||
std::string(get_type_info<OrderType>()) + "\n";
|
||||
} else {
|
||||
@@ -401,15 +411,6 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string
|
||||
msg += " → T::gridwise_gemm member: [✓]\n";
|
||||
using GG = decltype(T::gridwise_gemm);
|
||||
|
||||
if constexpr (requires(GG t) { t.k0_per_block; }) {
|
||||
using K0Type = decltype(std::declval<GG>().k0_per_block);
|
||||
constexpr bool convertible = std::convertible_to<K0Type, size_t>;
|
||||
msg += " → gridwise_gemm.k0_per_block: " + std::string(CHECK_MARK(convertible)) +
|
||||
std::string(detail::get_type_info<K0Type>()) + "\n";
|
||||
} else {
|
||||
msg += " → gridwise_gemm.k0_per_block: [✗] (missing member)\n";
|
||||
}
|
||||
|
||||
if constexpr (requires(GG t) { t.k1; }) {
|
||||
using K1Type = decltype(std::declval<GG>().k1);
|
||||
constexpr bool convertible = std::convertible_to<K1Type, size_t>;
|
||||
|
||||
@@ -69,7 +69,7 @@ struct ConvBwdWeightXdlFactory
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
GRIDWISE_GEMM.k0_per_block,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
|
||||
@@ -63,9 +63,9 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
return BwdBlockTransfer{
|
||||
.thread_cluster_dims = {1, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {0, block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {0, src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
|
||||
@@ -21,7 +21,7 @@ constexpr auto SIGNATURE =
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x32)
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
|
||||
@@ -35,11 +35,8 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create)
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"Intrawave",
|
||||
"v3",
|
||||
"GNHWC,GKYXC,EmptyTuple,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"MNKPadding"});
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough"});
|
||||
}
|
||||
|
||||
// TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
|
||||
@@ -49,7 +49,6 @@ static_assert(ckb::GridwiseFwdXdlGemmDescriptor<GridwiseFwdXdlGemm>);
|
||||
|
||||
struct GridwiseBwdXdlGemm
|
||||
{
|
||||
size_t k0_per_block = 0;
|
||||
size_t k1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
@@ -75,13 +74,25 @@ struct BlockGemm
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
|
||||
// Describe A and B block transfer thread cluster lengths.
|
||||
// Primary template of the block transfer descriptor. Because IsBwd == false is
// explicitly specialized further down in this file, this definition is the one
// used when IsBwd == true; it differs from that specialization only by the
// extra k_batch_size member.
// NOTE(review): member order matters for designated initializers
// ({.k0, .m_n, .k1, .k_batch_size}); do not reorder.
template <bool IsBwd = false>
|
||||
struct BlockTransfer
|
||||
{
|
||||
// Cluster length along k0.
size_t k0;
|
||||
// Cluster length along the m (A tensor) or n (B tensor) dimension - presumably; confirm with callers.
size_t m_n;
|
||||
// Cluster length along k1.
size_t k1;
|
||||
// GEMM-K batch size.
size_t k_batch_size;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Specialization for forward (IsBwd = false)
|
||||
// Explicit specialization selected when IsBwd == false: it carries only k0,
// m_n and k1 and has no k_batch_size member.
template <>
|
||||
struct BlockTransfer<false>
|
||||
{
|
||||
// Cluster length along k0.
size_t k0;
|
||||
// Cluster length along the m (A tensor) or n (B tensor) dimension - presumably; confirm with callers.
size_t m_n;
|
||||
// Cluster length along k1.
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<>>);
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<true>>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
@@ -111,31 +122,35 @@ struct Epilogue
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
std::array<size_t, ThreadSliceLength> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<>>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
|
||||
|
||||
struct TransferAB
|
||||
template <bool IsBwd = false>
|
||||
struct InputTransfer
|
||||
{
|
||||
BlockTransfer block_transfer;
|
||||
BlockTransfer<IsBwd> block_transfer;
|
||||
LdsTransfer lds_transfer;
|
||||
AccessOrder block_transfer_access_order;
|
||||
AccessOrder src_access_order;
|
||||
std::conditional_t<IsBwd, AccessOrder<4>, AccessOrder<3>> block_transfer_access_order;
|
||||
std::conditional_t<IsBwd, AccessOrder<4>, AccessOrder<3>> src_access_order;
|
||||
};
|
||||
|
||||
struct TransferC
|
||||
struct OutputTransfer
|
||||
{
|
||||
ThreadCluster thread_cluster_dims;
|
||||
Epilogue epilogue;
|
||||
};
|
||||
|
||||
struct TransferABC
|
||||
template <bool IsBwd = false>
|
||||
struct Transfer
|
||||
{
|
||||
TransferAB a;
|
||||
TransferAB b;
|
||||
TransferC c;
|
||||
InputTransfer<IsBwd> a;
|
||||
InputTransfer<IsBwd> b;
|
||||
OutputTransfer c;
|
||||
};
|
||||
|
||||
// DL-specific descriptors
|
||||
@@ -198,9 +213,10 @@ struct WmmaGemm_
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
template <bool IsBwd = false>
|
||||
struct Transfer_
|
||||
{
|
||||
TransferABC transfer;
|
||||
Transfer<IsBwd> transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecializationFwd_
|
||||
@@ -380,7 +396,8 @@ struct ConvAlgorithmTemplate : Components...
|
||||
template <typename T>
|
||||
constexpr auto with_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<Transfer_<>, ConvAlgorithmTemplate> ||
|
||||
std::is_base_of_v<Transfer_<true>, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
@@ -511,13 +528,13 @@ struct ConvAlgorithmTemplate : Components...
|
||||
// Algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_, ConvSpecializationFwd_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_, ConvSpecializationFwd_, BlockGemm_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecializationFwd_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<>, ConvSpecializationFwd_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
@@ -536,6 +553,6 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
|
||||
TileOptimizations_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_, ConvSpecializationBwdWeight_, TransposeParams_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<true>, ConvSpecializationBwdWeight_, TransposeParams_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -39,7 +39,7 @@ constexpr DlTransferABC DlFwdTransfer{.a =
|
||||
.dst_scalar_per_vector = 4},
|
||||
}};
|
||||
|
||||
constexpr TransferABC Transfer_4x64x1{
|
||||
constexpr Transfer<> Transfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -72,28 +72,29 @@ constexpr TransferABC Transfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC BwdTransfer_4x64x1{
|
||||
constexpr bool BWD = true;
|
||||
constexpr Transfer<BWD> BwdTransfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {3, 1, 2},
|
||||
.src_access_order = {2, 1, 3},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {3, 1, 2},
|
||||
.src_access_order = {2, 1, 3},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
@@ -105,7 +106,7 @@ constexpr TransferABC BwdTransfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC Transfer_4x64x1_fp8{
|
||||
constexpr Transfer<> Transfer_4x64x1_fp8{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -138,7 +139,7 @@ constexpr TransferABC Transfer_4x64x1_fp8{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC Transfer_4x16x1{
|
||||
constexpr Transfer<> Transfer_4x16x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
@@ -172,7 +173,7 @@ constexpr TransferABC Transfer_4x16x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC Transfer_4x32x1{
|
||||
constexpr Transfer<> Transfer_4x32x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
@@ -206,7 +207,7 @@ constexpr TransferABC Transfer_4x32x1{
|
||||
};
|
||||
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k0_per_block = 8, .k1 = 8,
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
@@ -244,6 +245,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
// Thread block preset: block_size of 256 with a tile of m = 128, n = 128, k = 8.
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 8}};
|
||||
|
||||
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ template <>
|
||||
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.k0_per_block << "," << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
|
||||
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
|
||||
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
@@ -120,10 +120,17 @@ inline std::string to_string<BlockGemm>(BlockGemm t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockTransfer>(BlockTransfer t)
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(BlockTransfer<IsBwd> t)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
if constexpr (IsBwd)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -143,14 +150,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<AccessOrder>(AccessOrder t)
|
||||
template <size_t N>
|
||||
inline std::string to_string(AccessOrder<N> t)
|
||||
{
|
||||
return array_to_seq(t.order);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferAB>(TransferAB t)
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(InputTransfer<IsBwd> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
|
||||
@@ -161,7 +168,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferC>(TransferC t)
|
||||
inline std::string to_string<OutputTransfer>(OutputTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
|
||||
@@ -169,8 +176,8 @@ inline std::string to_string<TransferC>(TransferC t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferABC>(TransferABC t)
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(Transfer<IsBwd> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -260,8 +267,8 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Transfer_>(Transfer_ t)
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(Transfer_<IsBwd> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
@@ -323,7 +330,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -333,7 +340,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -343,7 +350,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -371,8 +378,9 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuff
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
constexpr bool BWD = true;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<BWD>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user