Refactor handling of GEMM-K batch template parameter in conv bwd weight factory.

This commit is contained in:
Ville Pietilä
2025-12-23 10:08:56 -05:00
parent 608266a4ef
commit a1740c614b
8 changed files with 109 additions and 69 deletions

View File

@@ -51,12 +51,24 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Detects whether T exposes a `k_batch_size` member. Only the presence of the
// member is checked here; its type is not constrained.
template <typename T>
concept HasGemmKBatch = requires(T descriptor) { descriptor.k_batch_size; };
// Holds when T either has no `k_batch_size` member at all, or has one whose
// type is convertible to size_t. A `k_batch_size` member of any other type
// fails the concept.
template <typename T>
concept GemmKBatchSizeWellDefinedIfProvided =
    !HasGemmKBatch<T> || requires(T descriptor) {
        { descriptor.k_batch_size } -> std::convertible_to<size_t>;
    };
// Concept for vectorized data transfer for convolution input tensors.
//
// Requires members `k0`, `m_n` and `k1`, each convertible to size_t. A
// `k_batch_size` member is optional, but when it is present it must also be
// convertible to size_t.
template <typename T>
concept BlockTransferDescriptor = requires(T t) {
    { t.k0 } -> std::convertible_to<size_t>;
    { t.m_n } -> std::convertible_to<size_t>;
    { t.k1 } -> std::convertible_to<size_t>;
    // Nested requirement: the leading `requires` is what makes the concept get
    // evaluated. Written as a bare expression it would be a simple requirement,
    // which only checks that the expression is well-formed and so would accept
    // a k_batch_size of any type.
    requires GemmKBatchSizeWellDefinedIfProvided<T>;
};
// Concept for thread cluster dimensions for GEMM output tensor.
@@ -91,6 +103,8 @@ concept EpilogueDescriptor = requires(T t) {
template <typename T>
concept AccessOrderDescriptor =
    // Alternative 1: `order` is convertible to a three-entry array.
    requires(T desc) { { desc.order } -> std::convertible_to<std::array<size_t, 3>>; } ||
    // Alternative 2: `order` is convertible to a four-entry array.
    requires(T desc) { { desc.order } -> std::convertible_to<std::array<size_t, 4>>; };
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
@@ -166,7 +180,6 @@ concept GridwiseFwdXdlGemmDescriptor = requires (T t){
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept GridwiseBwdXdlGemmDescriptor = requires (T t){
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};

View File

@@ -181,6 +181,14 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string {
msg += std::string("") + prefix + ".k1: [✗] (missing member)\n";
}
// k_batch_size is optional - only report if it exists and has wrong type
if constexpr (requires(BT t) { t.k_batch_size; }) {
using KBatchType = decltype(std::declval<BT>().k_batch_size);
constexpr bool convertible = std::convertible_to<KBatchType, size_t>;
msg += std::string("") + prefix + ".k_batch_size (optional): " + std::string(CHECK_MARK(convertible)) +
std::string(get_type_info<KBatchType>()) + "\n";
}
return msg;
}
@@ -288,7 +296,9 @@ consteval auto diagnose_access_order(const char* prefix) -> std::string {
if constexpr (requires(AO t) { t.order; }) {
using OrderType = decltype(std::declval<AO>().order);
constexpr bool convertible = std::convertible_to<OrderType, std::array<size_t, 3>>;
constexpr bool convertible_3 = std::convertible_to<OrderType, std::array<size_t, 3>>;
constexpr bool convertible_4 = std::convertible_to<OrderType, std::array<size_t, 4>>;
constexpr bool convertible = convertible_3 || convertible_4;
msg += std::string("") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) +
std::string(get_type_info<OrderType>()) + "\n";
} else {
@@ -401,15 +411,6 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string
msg += " → T::gridwise_gemm member: [✓]\n";
using GG = decltype(T::gridwise_gemm);
if constexpr (requires(GG t) { t.k0_per_block; }) {
using K0Type = decltype(std::declval<GG>().k0_per_block);
constexpr bool convertible = std::convertible_to<K0Type, size_t>;
msg += " → gridwise_gemm.k0_per_block: " + std::string(CHECK_MARK(convertible)) +
std::string(detail::get_type_info<K0Type>()) + "\n";
} else {
msg += " → gridwise_gemm.k0_per_block: [✗] (missing member)\n";
}
if constexpr (requires(GG t) { t.k1; }) {
using K1Type = decltype(std::declval<GG>().k1);
constexpr bool convertible = std::convertible_to<K1Type, size_t>;

View File

@@ -69,7 +69,7 @@ struct ConvBwdWeightXdlFactory
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
GRIDWISE_GEMM.k0_per_block,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,

View File

@@ -63,9 +63,9 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
auto& lds_cfg = TRANSFER.lds_transfer;
return BwdBlockTransfer{
.thread_cluster_dims = {1, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {0, block_order.order[0], block_order.order[1], block_order.order[2]},
.src_access_order = {0, src_order.order[0], src_order.order[1], src_order.order[2]},
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]},
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,

View File

@@ -21,7 +21,7 @@ constexpr auto SIGNATURE =
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_256_128x128x32)
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
@@ -35,11 +35,8 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create)
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"Intrawave",
"v3",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}
// TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd)

View File

@@ -49,7 +49,6 @@ static_assert(ckb::GridwiseFwdXdlGemmDescriptor<GridwiseFwdXdlGemm>);
struct GridwiseBwdXdlGemm
{
size_t k0_per_block = 0;
size_t k1 = 0;
XdlParams xdl_params;
};
@@ -75,13 +74,25 @@ struct BlockGemm
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
// Describe A and B block transfer thread cluster lengths.
//
// Primary template. BlockTransfer<false> is explicitly specialized, so this
// definition is the one used for IsBwd = true. It carries the k_batch_size
// member in addition to k0, m_n and k1.
template <bool IsBwd = false>
struct BlockTransfer
{
size_t k0;
size_t m_n;
size_t k1;
// GEMM-K batch size. Declared last, so designated initializers must name it
// after k1.
size_t k_batch_size;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
// Specialization for forward (IsBwd = false): only k0, m_n and k1 are
// present; there is no k_batch_size member. Because IsBwd defaults to false,
// `BlockTransfer<>` names this specialization.
template <>
struct BlockTransfer<false>
{
size_t k0;
size_t m_n;
size_t k1;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<>>);
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<true>>);
// Describe C block transfer thread cluster lengths.
struct ThreadCluster
@@ -111,31 +122,35 @@ struct Epilogue
};
static_assert(EpilogueDescriptor<Epilogue>);
template <size_t ThreadSliceLength = 3>
struct AccessOrder
{
std::array<size_t, 3> order;
std::array<size_t, ThreadSliceLength> order;
};
static_assert(AccessOrderDescriptor<AccessOrder>);
static_assert(AccessOrderDescriptor<AccessOrder<>>);
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
struct TransferAB
template <bool IsBwd = false>
struct InputTransfer
{
BlockTransfer block_transfer;
BlockTransfer<IsBwd> block_transfer;
LdsTransfer lds_transfer;
AccessOrder block_transfer_access_order;
AccessOrder src_access_order;
std::conditional_t<IsBwd, AccessOrder<4>, AccessOrder<3>> block_transfer_access_order;
std::conditional_t<IsBwd, AccessOrder<4>, AccessOrder<3>> src_access_order;
};
struct TransferC
struct OutputTransfer
{
ThreadCluster thread_cluster_dims;
Epilogue epilogue;
};
struct TransferABC
template <bool IsBwd = false>
struct Transfer
{
TransferAB a;
TransferAB b;
TransferC c;
InputTransfer<IsBwd> a;
InputTransfer<IsBwd> b;
OutputTransfer c;
};
// DL-specific descriptors
@@ -198,9 +213,10 @@ struct WmmaGemm_
GridwiseWmmaGemm gridwise_gemm;
};
template <bool IsBwd = false>
struct Transfer_
{
TransferABC transfer;
Transfer<IsBwd> transfer;
};
struct ConvSpecializationFwd_
@@ -380,7 +396,8 @@ struct ConvAlgorithmTemplate : Components...
template <typename T>
constexpr auto with_transfer(const T& t) const
{
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<Transfer_<>, ConvAlgorithmTemplate> ||
std::is_base_of_v<Transfer_<true>, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
@@ -511,13 +528,13 @@ struct ConvAlgorithmTemplate : Components...
// Algorithm types
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_, ConvSpecializationFwd_, Prefetch_>;
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_, ConvSpecializationFwd_, BlockGemm_>;
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, BlockGemm_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecializationFwd_, Prefetch_>;
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<>, ConvSpecializationFwd_, Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvAlgorithmTemplate<ThreadBlock_,
@@ -536,6 +553,6 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
TileOptimizations_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_, ConvSpecializationBwdWeight_, TransposeParams_>;
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<true>, ConvSpecializationBwdWeight_, TransposeParams_>;
} // namespace ck_tile::builder::test

View File

@@ -39,7 +39,7 @@ constexpr DlTransferABC DlFwdTransfer{.a =
.dst_scalar_per_vector = 4},
}};
constexpr TransferABC Transfer_4x64x1{
constexpr Transfer<> Transfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
@@ -72,28 +72,29 @@ constexpr TransferABC Transfer_4x64x1{
},
};
constexpr TransferABC BwdTransfer_4x64x1{
constexpr bool BWD = true;
constexpr Transfer<BWD> BwdTransfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {3, 1, 2},
.src_access_order = {2, 1, 3},
.block_transfer_access_order = {0, 3, 1, 2},
.src_access_order = {0, 2, 1, 3},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {3, 1, 2},
.src_access_order = {2, 1, 3},
.block_transfer_access_order = {0, 3, 1, 2},
.src_access_order = {0, 2, 1, 3},
},
.c =
{
@@ -105,7 +106,7 @@ constexpr TransferABC BwdTransfer_4x64x1{
},
};
constexpr TransferABC Transfer_4x64x1_fp8{
constexpr Transfer<> Transfer_4x64x1_fp8{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
@@ -138,7 +139,7 @@ constexpr TransferABC Transfer_4x64x1_fp8{
},
};
constexpr TransferABC Transfer_4x16x1{
constexpr Transfer<> Transfer_4x16x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
@@ -172,7 +173,7 @@ constexpr TransferABC Transfer_4x16x1{
},
};
constexpr TransferABC Transfer_4x32x1{
constexpr Transfer<> Transfer_4x32x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -206,7 +207,7 @@ constexpr TransferABC Transfer_4x32x1{
};
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
.k0_per_block = 8, .k1 = 8,
.k1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
@@ -244,6 +245,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
// Thread block preset: block_size 256 with a 128 x 128 x 8 (m x n x k) tile.
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 8}};
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};

View File

@@ -89,7 +89,7 @@ template <>
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
{
std::ostringstream oss;
oss << t.k0_per_block << "," << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
return oss.str();
}
@@ -120,10 +120,17 @@ inline std::string to_string<BlockGemm>(BlockGemm t)
return oss.str();
}
template <>
inline std::string to_string<BlockTransfer>(BlockTransfer t)
template <bool IsBwd>
inline std::string to_string(BlockTransfer<IsBwd> t)
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
if constexpr (IsBwd)
{
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
}
else
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
}
}
template <>
@@ -143,14 +150,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
return oss.str();
}
template <>
inline std::string to_string<AccessOrder>(AccessOrder t)
// Formats an AccessOrder of any length by passing its `order` array to
// array_to_seq.
template <size_t N>
inline std::string to_string(AccessOrder<N> t)
{
    const auto& order = t.order;
    return array_to_seq(order);
}
template <>
inline std::string to_string<TransferAB>(TransferAB t)
template <bool IsBwd>
inline std::string to_string(InputTransfer<IsBwd> t)
{
std::ostringstream oss;
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
@@ -161,7 +168,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
}
template <>
inline std::string to_string<TransferC>(TransferC t)
inline std::string to_string<OutputTransfer>(OutputTransfer t)
{
std::ostringstream oss;
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
@@ -169,8 +176,8 @@ inline std::string to_string<TransferC>(TransferC t)
return oss.str();
}
template <>
inline std::string to_string<TransferABC>(TransferABC t)
template <bool IsBwd>
inline std::string to_string(Transfer<IsBwd> t)
{
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
@@ -260,8 +267,8 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
return to_string(t.gridwise_gemm);
}
template <>
inline std::string to_string<Transfer_>(Transfer_ t)
// Formats a Transfer_ component by delegating to the to_string overload for
// its `transfer` member.
template <bool IsBwd>
inline std::string to_string(Transfer_<IsBwd> t)
{
    const auto& inner = t.transfer;
    return to_string(inner);
}
@@ -323,7 +330,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -333,7 +340,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -343,7 +350,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -371,8 +378,9 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuff
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
{
std::ostringstream oss;
constexpr bool BWD = true;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
<< "," << to_string(static_cast<Transfer_<BWD>>(t));
return oss.str();
}