[CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection (#3592)

* added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp

* added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle

* added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3

* added reflection of max_transpose parameters

* fix printing of std optional parameters

* fix use of undefined ck::index

* added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle

* added xdl two stage instance to reflection

* added additional variables

* added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle, its _v3 variant, and grouped_conv_two_stage_wmma_cshuffle_v3

* added reflection for device_grouped_conv_bwd_weight_wmma_cshuffle_v3

* added reflection for bwd_weight_wmma_cshuffle

* added comments back in

* add printed output for optional parameters

* update README

* fix typo

* added num_gemm_k_prefetch_stage and small fixes

* modified test string due to reflection of new parameter

---------

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
Author: kabrahamAMD
Date: 2026-01-28 17:33:45 +01:00
Committed by: GitHub
parent bc6083bdd4
commit d6cccf6093
29 changed files with 2555 additions and 489 deletions


@@ -9,6 +9,7 @@ See the [main builder documentation](../README.md) for an overview.
The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation.
1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition.
This template is shared by the XDL and WMMA kernels, for both forward and backward weight convolutions. `std::optional` is used for parameters that only some kernels define.
2. **Description Generation**: The `describe<Instance>()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` (`Description`) object.
@@ -48,6 +49,15 @@ The reflection system (`ckr::describe`) currently supports the following convolu
- **Standard XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle`)
- **Large Tensor XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor`)
- **V3 XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`)
- **WMMA Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
- **XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- **V3 XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3`)
- **XDL Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle`)
- **Two Stage XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle`)
- **V3 Two Stage WMMA Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3`)
- **WMMA Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffle`)
- **V3 WMMA Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffleV3`)
- **V3 WMMA Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)
These variants all share similar template parameter structures and are compatible with the current `ConvTraits` implementation.
@@ -59,15 +69,6 @@ The following instance types are **not yet supported** by the reflection system:
- Uses different internal structure with parameters like `K0PerBlock`, `K1`, `M1PerThread`, etc.
- Missing standard members like `kKPerBlock`, `kMPerXDL`, `kAK1`
- **WMMA Variants** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
- Uses WMMA-specific parameters like `MPerWmma`, `NPerWmma`, `MRepeat`, `NRepeat`
- Different tile transfer structure incompatible with current `ConvTraits`
- **Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- Uses different layout naming: `InLayout`, `WeiLayout`, `OutLayout` instead of `ALayout`, `BLayout`, `ELayout`
- Different specialization type: `ConvBackwardWeightSpecialization` vs `ConvForwardSpecialization`
- Missing several members expected by forward convolution traits
### Future Work
To support these additional instance types, the reflection system would need:


@@ -29,30 +29,7 @@ conv::ConvDescription describe()
const auto traits = conv::instance_to_conv_traits<Instance>();
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = traits.spatial_dim,
.direction = traits.direction,
.input_layout = traits.layout[0],
.weight_layout = traits.layout[1],
.output_layout = traits.layout[2],
.data_type = traits.data_type,
.input_element_op = traits.input_element_op,
.weight_element_op = traits.weight_element_op,
.output_element_op = traits.output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = traits.thread_block_size,
.tile_dims = traits.tile_dims,
.warp_gemm = traits.warp_gemm,
.a_tile_transfer = traits.a_tile_transfer,
.b_tile_transfer = traits.b_tile_transfer,
.c_tile_transfer = traits.c_tile_transfer,
.pipeline_version = traits.pipeline_version,
.pipeline_scheduler = traits.pipeline_scheduler,
.conv_specialization = traits.conv_specialization,
.padding = traits.gemm_padding,
},
[]<typename T = Instance>() { return reflect::instance_string<T>(); });
traits, []<typename T = Instance>() { return reflect::instance_string<T>(); });
}
} // namespace ck_tile::reflect


@@ -29,44 +29,12 @@
#include <ck_tile/builder/reflect/description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
namespace ck_tile::reflect {
namespace conv {
/// @brief Signature information for a convolution operation
/// Contains high-level properties that define the convolution's interface,
/// including dimensionality, data layout, data types, and elementwise operations.
struct ConvSignatureInfo
{
int spatial_dim;
builder::ConvDirection direction;
builder::TensorLayout input_layout;
builder::TensorLayout weight_layout;
builder::TensorLayout output_layout;
builder::DataType data_type;
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
};
/// @brief Algorithm configuration for a convolution kernel
/// Contains low-level implementation details including thread block configuration,
/// tile dimensions, memory access patterns, and pipeline settings.
struct GemmAlgorithmInfo
{
int thread_block_size;
DataTileInfo tile_dims;
WarpGemmParams warp_gemm;
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
OutputTileTransferInfo c_tile_transfer;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
builder::ConvSpecialization conv_specialization;
builder::GemmPadding padding;
};
/// @brief Provides human-readable descriptions of convolution kernel instances
/// Generates formatted text descriptions at various levels of detail for
/// understanding and documenting convolution kernel configurations.
@@ -74,16 +42,12 @@ class ConvDescription : public Description
{
public:
/// @brief Constructor for ConvDescription
/// @param sig The signature information containing high-level convolution properties
/// @param algo The algorithm configuration containing low-level implementation details
/// @param traits The ConvTraits object containing all relevant signature and algorithm
/// information
/// @param instance_string_getter A callable that returns a string representation of the
/// instance
ConvDescription(ConvSignatureInfo sig,
GemmAlgorithmInfo algo,
std::function<std::string()> instance_string_getter)
: signature_(std::move(sig)),
algorithm_(std::move(algo)),
instance_string_getter_(std::move(instance_string_getter))
ConvDescription(ConvTraits traits, std::function<std::string()> instance_string_getter)
: traits_(std::move(traits)), instance_string_getter_(std::move(instance_string_getter))
{
}
@@ -92,7 +56,7 @@ class ConvDescription : public Description
std::string brief() const override
{
std::ostringstream oss;
oss << signature_.spatial_dim << "D " << signature_.direction << " convolution";
oss << traits_.spatial_dim << "D " << traits_.direction << " convolution";
return oss.str();
}
@@ -101,39 +65,42 @@ class ConvDescription : public Description
std::string detailed() const override
{
TreeFormatter f;
f.writeLine(0, signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel");
f.writeLine(0, traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", signature_.data_type);
f.writeLine(2, "Input Layout: ", signature_.input_layout);
f.writeLine(2, "Weight Layout: ", signature_.weight_layout);
f.writeLine(2, "Output Layout: ", signature_.output_layout);
f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op);
f.writeLine(2, "Tensor Type: ", traits_.data_type);
f.writeLine(2, "Input Layout: ", traits_.layout[0]);
f.writeLine(2, "Weight Layout: ", traits_.layout[1]);
f.writeLine(2, "Output Layout: ", traits_.layout[2]);
f.writeLine(2, "Input elementwise operation: ", traits_.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", traits_.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", traits_.output_element_op);
f.writeLast(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm_.thread_block_size);
f.writeLine(2, "Thread block size: ", traits_.thread_block_size);
f.writeLine(2,
"Data tile size: ",
algorithm_.tile_dims.m,
traits_.tile_dims.m,
"×",
algorithm_.tile_dims.n,
traits_.tile_dims.n,
"×",
algorithm_.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm_.padding);
f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization);
traits_.tile_dims.k);
if(traits_.gemm_padding)
f.writeLine(
2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
else
f.writeLine(2, "Struct does not contain optional gemm_padding argument");
f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm_.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", algorithm_.pipeline_scheduler);
f.writeLine(2, "Pipeline version: ", traits_.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", traits_.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(
3, "subtile size: ", algorithm_.warp_gemm.gemm_m, "×", algorithm_.warp_gemm.gemm_n);
f.writeLine(3, "subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
algorithm_.warp_gemm.m_iter,
traits_.warp_gemm.m_iter,
"×",
algorithm_.warp_gemm.n_iter);
traits_.warp_gemm.n_iter);
// Memory Access section
f.writeLast(2, "Memory access:");
@@ -141,99 +108,126 @@ class ConvDescription : public Description
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm_.a_tile_transfer.tile_dimensions.k0,
traits_.a_tile_transfer.tile_dimensions.k0,
"×",
algorithm_.a_tile_transfer.tile_dimensions.m_or_n,
traits_.a_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm_.a_tile_transfer.tile_dimensions.k1,
traits_.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(4,
"The innermost K subdimension size: ",
algorithm_.a_tile_transfer.transfer_params.k1);
f.writeLine(
4, "The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[0],
traits_.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[1],
traits_.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm_.a_tile_transfer.transfer_params.src_access_order[0],
traits_.a_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm_.a_tile_transfer.transfer_params.src_access_order[1],
traits_.a_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm_.a_tile_transfer.transfer_params.src_access_order[2]);
traits_.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm_.a_tile_transfer.transfer_params.src_vector_dim);
traits_.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm_.a_tile_transfer.transfer_params.src_scalar_per_vector);
traits_.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm_.b_tile_transfer.tile_dimensions.k0,
traits_.b_tile_transfer.tile_dimensions.k0,
"×",
algorithm_.b_tile_transfer.tile_dimensions.m_or_n,
traits_.b_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm_.b_tile_transfer.tile_dimensions.k1,
traits_.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(4,
"The innermost K subdimension size: ",
algorithm_.b_tile_transfer.transfer_params.k1);
f.writeLine(
4, "The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[0],
traits_.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[1],
traits_.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm_.b_tile_transfer.transfer_params.src_access_order[0],
traits_.b_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm_.b_tile_transfer.transfer_params.src_access_order[1],
traits_.b_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm_.b_tile_transfer.transfer_params.src_access_order[2]);
traits_.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm_.b_tile_transfer.transfer_params.src_vector_dim);
traits_.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm_.b_tile_transfer.transfer_params.src_scalar_per_vector);
traits_.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
algorithm_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
algorithm_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
algorithm_.c_tile_transfer.thread_cluster_dims[0],
traits_.c_tile_transfer.thread_cluster_dims[0],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[1],
traits_.c_tile_transfer.thread_cluster_dims[1],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[2],
traits_.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
traits_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLine(4,
"Vector access (GMEM write) instruction size: ",
algorithm_.c_tile_transfer.scalar_per_vector);
traits_.c_tile_transfer.scalar_per_vector);
if(traits_.num_gemm_k_prefetch_stage)
f.writeLine(
2, "Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"num_gemm_k_prefetch_stage parameter");
if(traits_.max_transpose_transfer_src_scalar_per_vector)
f.writeLine(2,
"Max Transpose transfer src scalar per vector: ",
traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_src_scalar_per_vector parameter");
if(traits_.max_transpose_dst_scalar_per_vector)
f.writeLine(2,
"Max Transpose dst scalar per vector: ",
traits_.max_transpose_dst_scalar_per_vector.value_or(0));
else
f.writeLine(
2,
"Struct does not contain optional max_transpose_dst_scalar_per_vector parameter");
if(traits_.num_groups_to_merge)
f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
else
f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter");
return f.getString();
}
@@ -242,8 +236,7 @@ class ConvDescription : public Description
std::string instance_string() const override { return instance_string_getter_(); }
private:
ConvSignatureInfo signature_;
GemmAlgorithmInfo algorithm_;
ConvTraits traits_;
std::function<std::string()> instance_string_getter_;
};


@@ -88,7 +88,7 @@ struct ConvTraits
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
builder::GemmPadding gemm_padding;
std::optional<builder::GemmPadding> gemm_padding = std::nullopt;
builder::ConvSpecialization conv_specialization;
// --- Algorithm Information ---
@@ -102,8 +102,14 @@ struct ConvTraits
OutputTileTransferInfo c_tile_transfer;
std::optional<int> num_gemm_k_prefetch_stage = std::nullopt;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
std::optional<int> max_transpose_transfer_src_scalar_per_vector = std::nullopt;
std::optional<int> max_transpose_dst_scalar_per_vector = std::nullopt;
std::optional<int> num_groups_to_merge = std::nullopt;
};
} // namespace ck_tile::reflect::conv


@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv


@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv


@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}
} // namespace ck_tile::reflect::conv


@@ -0,0 +1,57 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}
} // namespace ck_tile::reflect::conv


@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv


@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,56 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),

View File

@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),

View File

@@ -80,6 +80,22 @@ namespace ck_tile::reflect::conv {
// SECTION 1: ENUM CONVERSIONS
// ============================================================================
// Forward convolution layout concept - checks for A/B/E layout types
template <typename T>
concept HasFwdConvLayouts = requires {
typename T::ALayout;
typename T::BLayout;
typename T::ELayout;
};
// Backward weight layout concept - checks for In, Wei, and Out layout types
template <typename T>
concept HasBwdWeiLayouts = requires {
typename T::InLayout;
typename T::WeiLayout;
typename T::OutLayout;
};
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value.
@@ -322,12 +338,25 @@ constexpr builder::ConvSpecialization conv_spec()
// Tensor Layouts
// ----------------------------------------------------------------------------
// Helper variable template to check if CK layout types match
template <typename A,
typename B,
typename E,
typename ExpectedA,
typename ExpectedB,
typename ExpectedE>
inline constexpr bool layouts_are =
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
/// @details This consteval function is designed to fail at compile time with a descriptive
/// error message when an unsupported layout combination is encountered.
template <typename A, typename B, typename E, int SpatialDim>
[[noreturn]] consteval void report_unsupported_layout_error()
{
// This will produce a compile-time error with the exception message
throw "Unsupported convolution layout combination detected!\n"
"The combination of ALayout, BLayout, and ELayout template parameters\n"
"is not recognized for the given spatial dimension.\n"
@@ -335,111 +364,99 @@ template <typename A, typename B, typename E, int SpatialDim>
"Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array<builder::TensorLayout, 3> containing the layouts for:
/// - [0] Input tensor layout
/// - [1] Weight tensor layout
/// - [2] Output tensor layout
/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
/// along with the spatial dimension to determine the appropriate layout configuration.
///
/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
///
/// @note Compilation will fail with a clear error message if the layout combination
/// is not supported for the given spatial dimension.
///
/// TODO: If we don't check for supported layouts, this function can be simplified.
template <typename Instance>
constexpr std::array<builder::TensorLayout, 3> conv_layout()
template <typename A, typename B, typename E, int kSpatialDim>
constexpr auto conv_layout()
{
using InstTraits = InstanceTraits<Instance>;
using A = typename InstTraits::ALayout;
using B = typename InstTraits::BLayout;
using E = typename InstTraits::ELayout;
namespace ctl = ck::tensor_layout::convolution;
// Helper lambda to construct layout array
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
namespace ctl = ck::tensor_layout::convolution;
using enum builder::TensorLayout;
// Helper to check if layouts match expected types
constexpr auto layouts_match = []<typename ExpA, typename ExpB, typename ExpE>() {
return std::is_same_v<A, ExpA> && std::is_same_v<B, ExpB> && std::is_same_v<E, ExpE>;
};
switch(kSpatialDim)
{
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
return layouts(NWGC, GKXC, NWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
return layouts(NGCW, GKXC, NGKW);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
return layouts(NGCW, GKCX, NGKW);
break;
case 2:
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
return layouts(NGCHW, GKYXC, NGKHW);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
return layouts(NGCHW, GKCYX, NGKHW);
break;
case 3:
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
return layouts(NDHWGC, GKZYXC, NDHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
return layouts(NGCDHW, GKZYXC, NGKDHW);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
return layouts(NGCDHW, GKCZYX, NGKDHW);
break;
}
// Helper to construct layout array
constexpr auto make_layouts = [](auto in, auto weight, auto out) {
return std::array<builder::TensorLayout, 3>{in, weight, out};
};
// If we reach here, the layout combination is not supported
// Call consteval function to trigger a compile-time error with a clear message
report_unsupported_layout_error<A, B, E, kSpatialDim>();
constexpr int spatial_dim = InstTraits::kSpatialDim;
// This return is unreachable but needed to satisfy the compiler
return layouts(GNHWC, GKYXC, GNHWK);
}
if constexpr(spatial_dim == 1)
{
if constexpr(layouts_match.template operator()<ctl::GNWC, ctl::GKXC, ctl::GNWK>())
return make_layouts(GNWC, GKXC, GNWK);
else if constexpr(layouts_match
.template operator()<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>())
return make_layouts(GNWC, GKXC, GNWK);
else if constexpr(layouts_match.template operator()<ctl::NWGC, ctl::GKXC, ctl::NWGK>())
return make_layouts(NWGC, GKXC, NWGK);
else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKXC, ctl::NGKW>())
return make_layouts(NGCW, GKXC, NGKW);
else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKCX, ctl::NGKW>())
return make_layouts(NGCW, GKCX, NGKW);
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNWC, GKXC, GNWK); // Unreachable
}
}
else if constexpr(spatial_dim == 2)
{
if constexpr(layouts_match.template operator()<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>())
return make_layouts(GNHWC, GKYXC, GNHWK);
else if constexpr(layouts_match
.template operator()<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>())
return make_layouts(GNHWC, GKYXC, GNHWK);
else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>())
return make_layouts(NHWGC, GKYXC, NHWGK);
else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>())
return make_layouts(NHWGC, GKYXC, NHWGK);
else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>())
return make_layouts(NGCHW, GKYXC, NGKHW);
else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>())
return make_layouts(NGCHW, GKCYX, NGKHW);
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
}
}
else if constexpr(spatial_dim == 3)
{
if constexpr(layouts_match.template operator()<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>())
return make_layouts(GNDHWC, GKZYXC, GNDHWK);
else if constexpr(layouts_match
.template operator()<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>())
return make_layouts(GNDHWC, GKZYXC, GNDHWK);
else if constexpr(layouts_match
.template operator()<ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>())
return make_layouts(NDHWGC, GKZYXC, NDHWGK);
else if constexpr(layouts_match
.template operator()<ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>())
return make_layouts(NGCDHW, GKZYXC, NGKDHW);
else if constexpr(layouts_match
.template operator()<ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>())
return make_layouts(NGCDHW, GKCZYX, NGKDHW);
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
}
}
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
constexpr auto fwd_conv_layout()
requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
using A = typename InstanceTraits<Instance>::ALayout;
using B = typename InstanceTraits<Instance>::BLayout;
using E = typename InstanceTraits<Instance>::ELayout;
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
constexpr auto bwd_wei_conv_layout()
requires HasBwdWeiLayouts<InstanceTraits<Instance>>
{
using A = typename InstanceTraits<Instance>::InLayout;
using B = typename InstanceTraits<Instance>::WeiLayout;
using E = typename InstanceTraits<Instance>::OutLayout;
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
}
// ----------------------------------------------------------------------------
@@ -447,13 +464,11 @@ constexpr std::array<builder::TensorLayout, 3> conv_layout()
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported data type with a clear error message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ADataType>
template <typename DataTypeFromInstance>
[[noreturn]] consteval void report_unsupported_data_type_error()
{
throw "Unsupported data type detected!\n"
"The ADataType is not recognized.\n"
"The DataTypeFromInstance is not recognized.\n"
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
@@ -462,62 +477,44 @@ template <typename ADataType>
"Please verify that your kernel instance uses a supported data type.";
}
/// @brief Derives the data type from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return A builder::DataType enum value representing the input data type.
/// @details This function examines the Instance's ADataType to determine the data type
/// used for the input tensor. The function supports various floating-point and integer
/// types, including tuple types for mixed-precision operations.
///
/// Supported data types include:
/// - FP16 (ck::half_t)
/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
/// - BF16 (ck::bhalf_t)
/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
/// - FP32 (float)
/// - FP32_FP32 (ck::Tuple<float, float>)
/// - FP64 (double)
/// - FP8 (ck::f8_t)
/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
/// - I8 (int8_t)
/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
/// - U8 (uint8_t)
template <typename Instance>
/// @brief Derives the data type from a device kernel `Instance` type.
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
// Note: maybe move to types.hpp?
template <typename DataTypeFromInstance>
constexpr builder::DataType conv_data_type()
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
using enum builder::DataType;
if constexpr(std::is_same_v<ADataType, ck::half_t>)
if constexpr(std::is_same_v<DataTypeFromInstance, ck::half_t>)
return FP16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<ck::half_t, ck::half_t>>)
return FP16_FP16;
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bhalf_t>)
return BF16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
return BF16_BF16;
else if constexpr(std::is_same_v<ADataType, float>)
else if constexpr(std::is_same_v<DataTypeFromInstance, float>)
return FP32;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<float, float>>)
return FP32_FP32;
else if constexpr(std::is_same_v<ADataType, double>)
else if constexpr(std::is_same_v<DataTypeFromInstance, double>)
return FP64;
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::f8_t>)
return FP8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bf8_fnuz_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bf8_ocp_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, int8_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, int8_t>)
return I8;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<int8_t, int8_t>>)
return I8_I8;
else if constexpr(std::is_same_v<ADataType, uint8_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, uint8_t>)
return U8;
else
{
report_unsupported_data_type_error<ADataType>();
report_unsupported_data_type_error<DataTypeFromInstance>();
return FP32; // Unreachable
}
}
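The generalized `conv_data_type<DataTypeFromInstance>()` above is a pure compile-time type-to-enum mapping built from an `if constexpr` chain. A self-contained sketch of the technique with a reduced placeholder enum and a stand-in for `ck::half_t`:

```cpp
#include <cstdint>
#include <type_traits>

// Illustrative subset of the builder enum (not the real builder::DataType).
enum class DataType { FP16, FP32, I8, Unknown };

struct half_t {}; // stand-in for ck::half_t

template <typename T>
constexpr DataType conv_data_type()
{
    if constexpr(std::is_same_v<T, half_t>)
        return DataType::FP16;
    else if constexpr(std::is_same_v<T, float>)
        return DataType::FP32;
    else if constexpr(std::is_same_v<T, std::int8_t>)
        return DataType::I8;
    else
        return DataType::Unknown; // the real code reports a compile-time error here
}
```

Because every branch is `if constexpr`, only the matching return statement is instantiated, so the function works even for types that would make other branches ill-formed.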
@@ -736,4 +733,92 @@ constexpr builder::PipelineScheduler get_pipeline_scheduler()
}
}
// ============================================================================
// SECTION 4: Helper functions for common structures often used in reflection
// ============================================================================
template <typename InstTraits>
constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock)
{
return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0};
}
template <typename InstTraits>
constexpr InputTileTransferInfo
conv_traits_a_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock)
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, .k1 = _k1},
.transfer_params = {.k1 = _k1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
}
template <typename InstTraits>
constexpr InputTileTransferInfo
conv_traits_b_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock)
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, .k1 = _k1},
.transfer_params = {.k1 = _k1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
}
template <typename InstTraits>
constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
{
return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma,
.gemm_n = InstTraits::kNPerWmma,
.m_iter = InstTraits::kMRepeat,
.n_iter = InstTraits::kNRepeat};
}
template <typename InstTraits>
constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
{
return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector};
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer()
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
}
} // namespace ck_tile::reflect::conv
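The helper functions in Section 4 all build nested aggregates with designated initializers, so every trait constant is named at the construction site. An illustrative reduction of that pattern (the struct shapes and trait names here are placeholders):

```cpp
// Placeholder aggregates mirroring the shape of OutputTileTransferInfo.
struct ShuffleParams {
    int m_gemms_per_shuffle;
    int n_gemms_per_shuffle;
};
struct OutputTileTransferInfo {
    ShuffleParams shuffle_params;
    int scalar_per_vector;
};

template <typename InstTraits>
constexpr OutputTileTransferInfo c_tile_transfer()
{
    // Designated initializers keep the mapping from trait constant to
    // description field explicit, and the compiler rejects out-of-order
    // or misspelled field names.
    return OutputTileTransferInfo{
        .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kMShuffle,
                           .n_gemms_per_shuffle = InstTraits::kNShuffle},
        .scalar_per_vector = InstTraits::kScalarPerVector};
}

// Hypothetical traits providing the constants the helper reads.
struct DemoTraits {
    static constexpr int kMShuffle = 1;
    static constexpr int kNShuffle = 2;
    static constexpr int kScalarPerVector = 8;
};
```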

View File

@@ -3,6 +3,18 @@
#pragma once
// Fwd instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
// Bwd weight instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"

View File

@@ -62,6 +62,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for the DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 device kernel
struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -158,7 +162,9 @@ struct InstanceTraits<
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -175,13 +181,13 @@ struct InstanceTraits<
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kKPerBlock = KPerBlock;
static constexpr ck::index_t kABK1 = ABK1;
static constexpr ck::index_t kK1 = ABK1;
static constexpr ck::index_t kMPerWmma = MPerWmma;
static constexpr ck::index_t kNPerWmma = NPerWmma;
static constexpr ck::index_t kMRepeat = MRepeat;
@@ -195,27 +201,46 @@ struct InstanceTraits<
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -231,7 +256,7 @@ struct InstanceTraits<
oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -251,30 +276,30 @@ struct InstanceTraits<
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kKPerBlock; // 18. KPerBlock
oss << "," << kABK1; // 19. ABK1
oss << "," << kMPerWmma; // 20. MPerWmma
oss << "," << kNPerWmma; // 21. NPerWmma
oss << "," << kMRepeat; // 22. MRepeat
oss << "," << kNRepeat; // 23. NRepeat
kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kKPerBlock; // 18. KPerBlock
oss << "," << kK1; // 19. ABK1
oss << "," << kMPerWmma; // 20. MPerWmma
oss << "," << kNPerWmma; // 21. NPerWmma
oss << "," << kMRepeat; // 22. MRepeat
oss << "," << kNRepeat; // 23. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMRepeatPerShuffle; // 38.
oss << "," << kCShuffleNRepeatPerShuffle; // 39.
oss << ","

View File

@@ -59,6 +59,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
@@ -152,7 +156,10 @@ struct InstanceTraits<
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -169,7 +176,7 @@ struct InstanceTraits<
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -188,22 +195,36 @@ struct InstanceTraits<
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
@@ -211,6 +232,9 @@ struct InstanceTraits<
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
// Static member function to generate instance string
static std::string instance_string()
{
@@ -220,7 +244,7 @@ struct InstanceTraits<
oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -240,30 +264,30 @@ struct InstanceTraits<
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kK1; // 19. K1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kK1; // 19. K1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 29.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39.
oss << ","

View File

@@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 device kernel
struct DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -161,7 +166,9 @@ struct InstanceTraits<
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -176,7 +183,7 @@ struct InstanceTraits<
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -201,12 +208,21 @@ struct InstanceTraits<
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
@@ -215,13 +231,26 @@ struct InstanceTraits<
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -237,7 +266,7 @@ struct InstanceTraits<
oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -255,30 +284,30 @@ struct InstanceTraits<
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kABK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kABK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
oss << ","

View File

@@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -160,7 +165,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -175,7 +182,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -199,26 +206,45 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -234,7 +260,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
oss << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -252,30 +278,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","

View File

@@ -59,6 +59,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -148,8 +153,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
false>> // Use false to match with the default value
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag;
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -164,15 +170,15 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerWMMA = MPerWMMA;
static constexpr ck::index_t kNPerWMMA = NPerWMMA;
static constexpr ck::index_t kMPerWmma = MPerWMMA;
static constexpr ck::index_t kNPerWmma = NPerWMMA;
static constexpr ck::index_t kMRepeat = MRepeat;
static constexpr ck::index_t kNRepeat = NRepeat;
static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
@@ -184,26 +190,43 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::LoopScheduler kLoopSched = LoopSched;
static constexpr ck::PipelineVersion kPipelineVer = PipelineVer;
@@ -216,7 +239,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -234,30 +257,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerWMMA; // 18. MPerWMMA
oss << "," << kNPerWMMA; // 19. NPerWMMA
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerWmma; // 18. MPerWMMA
oss << "," << kNPerWmma; // 19. NPerWMMA
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
oss << ","


@@ -62,6 +62,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 device kernel
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -156,8 +161,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
MaxTransposeTransferDstScalarPerVector>>
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3";
using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag;
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -172,13 +178,13 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kKPerBlock = KPerBlock;
static constexpr ck::index_t kABK1 = ABK1;
static constexpr ck::index_t kK1 = ABK1;
static constexpr ck::index_t kMPerWmma = MPerWmma;
static constexpr ck::index_t kNPerWmma = NPerWmma;
static constexpr ck::index_t kMRepeat = MRepeat;
@@ -196,27 +202,46 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -232,7 +257,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -250,30 +275,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kABK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
oss << ","


@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -152,7 +157,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -167,43 +173,63 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerXDL = MPerXDL;
static constexpr ck::index_t kNPerXDL = NPerXDL;
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerXDL = MPerXDL;
static constexpr ck::index_t kNPerXDL = NPerXDL;
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
@@ -224,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -242,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","


@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -150,9 +155,12 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
ComputeTypeA_,
ComputeTypeB_>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag;
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -167,7 +175,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -182,28 +190,48 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
@@ -222,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -240,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","


@@ -79,6 +79,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
} // namespace ck::tensor_operation::device
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device kernel
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag
{
};
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
@@ -176,6 +181,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
LoopSched,
PipelineVer>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag;
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;

File diff suppressed because it is too large


@@ -259,9 +259,118 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Instance = ckb::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
EXPECT_THAT(
ckr::describe<Instance>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Input Layout: GNHWC\n"
"│ ├─ Weight Layout: GKYXC\n"
"│ ├─ Output Layout: GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"└─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 256×256×32\n"
" ├─ Gemm padding: DEFAULT\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V4\n"
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ├─ Vector access (GMEM write) instruction size: 2\n"
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
"parameter\n"
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
" └─ Struct does not contain optional num_groups_to_merge parameter"));
}
// Test printing of optional parameters num_groups_to_merge,
// max_transpose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
{
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // InLayout
ck::tensor_layout::convolution::GKYXC, // WeiLayout
ck::tensor_layout::convolution::GNHWK, // OutLayout
ck::half_t, // InDataType
ck::half_t, // WeiDataType
ck::half_t, // OutDataType
float, // AccDataType
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
Default, // ConvBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // K0PerBlock
8, // AK1
32, // MPerWMMA
32, // NPerWMMA
4, // MRepeat
4, // NRepeat
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
1, // ABlockLdsAddExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
1, // BBlockLdsAddExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
ck::Sequence<1,
32,
1,
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
8, // CDEBlockTransferScalarPerVector_NPerBlock_
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
4, // NumGroupsToMerge
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
1, // MaxTransposeTransferSrcScalarPerVector
1>; // MaxTransposeTransferDstScalarPerVector
EXPECT_THAT(ckr::describe<Instance>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"2D Backward Weight Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Input Layout: GNHWC\n"
@@ -272,37 +381,146 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"└─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 256×256×32\n"
" ├─ Gemm padding: DEFAULT\n"
" ├─ Data tile size: 128×128×16\n"
" ├─ Struct does not contain optional gemm_padding argument\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V4\n"
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Pipeline version: V1\n"
" ├─ Pipeline scheduler: DEFAULT\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" │ ├─ subtile size: 32×32\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 2"));
" ─ Vector access (GMEM write) instruction size: 8\n"
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
" ├─ Max Transpose dst scalar per vector: 1\n"
" └─ Num groups to merge: 4"));
}
// Test printing of optional parameters num_groups_to_merge,
// max_transpose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleDescriptionTest)
{
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
3, // NDimSpatial
ck::tensor_layout::convolution::GNDHWC, // InLayout
ck::tensor_layout::convolution::GKZYXC, // WeiLayout
ck::tensor_layout::convolution::GNDHWK, // OutLayout
ck::half_t, // InDataType
ck::half_t, // WeiDataType
ck::half_t, // OutDataType
float, // AccDataType
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
Default, // ConvBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // K0PerBlock
8, // K1
32, // MPerWmma
32, // NPerWmma
4, // MRepeat
4, // NRepeat
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
1, // ABlockLdsAddExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
1, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
8, // CDEBlockTransferScalarPerVector_NPerBlock_
1,                                              // NumGemmKPrefetchStage
ck::LoopScheduler::Default, // BlkGemmPipeSched
ck::PipelineVersion::v1, // BlkGemmPipelineVer
false>; // BComputeDataType
EXPECT_THAT(
ckr::describe<Instance>().detailed(),
ckt::StringEqWithDiff( //
"3D Backward Weight Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Input Layout: GNDHWC\n"
"│ ├─ Weight Layout: GKZYXC\n"
"│ ├─ Output Layout: GNDHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"└─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 128×128×16\n"
" ├─ Struct does not contain optional gemm_padding argument\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V1\n"
" ├─ Pipeline scheduler: DEFAULT\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 32×32\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ├─ Vector access (GMEM write) instruction size: 8\n"
" ├─ Num gemm k prefetch stage: 1\n"
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
"parameter\n"
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
" └─ Struct does not contain optional num_groups_to_merge parameter"));
}
TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString)