mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection (#3592)
* added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp
* added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle
* added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3
* added reflection of max_transpose parameters
* fix printing of std optional parameters
* fix use of undefined ck::index
* added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle
* added xdl two stage instance to reflection
* added additional variables
* added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle, _v3, grouped_conv_two_stage_wmma_cshuffle_v3,
* added reflection for device_grouped_conv_bwd_weigh_wmma_cshuffle_v3
* added reflection for bwd_weight_wmma_cshuffle
* added comments back in
* add printed output for optional parameters
* update README
* fix typo
* added num_gemm_k_prefetch_stage and small fixes
* modified test string due to reflection of new parameter

---------

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
@@ -9,6 +9,7 @@ See the [main builder documentation](../README.md) for an overview.
The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation.

1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition.
   This template is shared by XDL and WMMA kernels, for both forward and backward weight convolutions. `std::optional` is used for parameters that only some kernels define.

2. **Description Generation**: The `describe<Instance>()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` (`Description`) object.

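The optional-parameter convention described in item 1 can be sketched as follows. This is a minimal illustration under stated assumptions, not the CK code: `MiniTraits`, `traits_with_prefetch` and `traits_with_merge` are hypothetical stand-ins for `ConvTraits` and the per-kernel `instance_to_conv_traits` overloads. Only the idea of `std::optional` fields that default to `std::nullopt` is taken from the commit.

```cpp
#include <optional>

// Hypothetical, reduced stand-in for ConvTraits: only the fields needed to
// show the optional-parameter pattern are included.
struct MiniTraits
{
    int thread_block_size;
    // Parameters that only some kernel families expose default to "absent".
    std::optional<int> num_gemm_k_prefetch_stage = std::nullopt;
    std::optional<int> num_groups_to_merge       = std::nullopt;
};

// A kernel family that has a prefetch-stage parameter but no group merging.
constexpr MiniTraits traits_with_prefetch()
{
    return MiniTraits{.thread_block_size = 256, .num_gemm_k_prefetch_stage = 1};
}

// A kernel family that exposes group merging only.
constexpr MiniTraits traits_with_merge()
{
    return MiniTraits{.thread_block_size = 128, .num_groups_to_merge = 2};
}

// Fields that are not named in the designated initializer stay empty.
static_assert(traits_with_prefetch().num_gemm_k_prefetch_stage.has_value());
static_assert(!traits_with_prefetch().num_groups_to_merge.has_value());
```

Because the defaults live in the struct, a conversion function only names the fields its kernel type actually has, and a consumer can tell "absent" apart from a real value of zero.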
@@ -48,6 +49,15 @@ The reflection system (`ckr::describe`) currently supports the following convolu
- **Standard XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle`)
- **Large Tensor XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor`)
- **V3 XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`)
- **V3 WMMA Forward Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)
- **XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- **V3 XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3`)
- **XDL Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle`)
- **Two Stage XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle`)
- **V3 Two Stage Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3`)
- **Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffle`)
- **V3 Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffleV3`)
- **V3 Wmma Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)

These variants all share similar template parameter structures and are compatible with the current `ConvTraits` implementation.

@@ -59,15 +69,6 @@ The following instance types are **not yet supported** by the reflection system:
  - Uses different internal structure with parameters like `K0PerBlock`, `K1`, `M1PerThread`, etc.
  - Missing standard members like `kKPerBlock`, `kMPerXDL`, `kAK1`

- **WMMA Variants** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
  - Uses WMMA-specific parameters like `MPerWmma`, `NPerWmma`, `MRepeat`, `NRepeat`
  - Different tile transfer structure incompatible with current `ConvTraits`

- **Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
  - Uses different layout naming: `InLayout`, `WeiLayout`, `OutLayout` instead of `ALayout`, `BLayout`, `ELayout`
  - Different specialization type: `ConvBackwardWeightSpecialization` vs `ConvForwardSpecialization`
  - Missing several members expected by forward convolution traits

### Future Work

To support these additional instance types, the reflection system would need:

@@ -29,30 +29,7 @@ conv::ConvDescription describe()
    const auto traits = conv::instance_to_conv_traits<Instance>();

    return conv::ConvDescription(
        conv::ConvSignatureInfo{
            .spatial_dim = traits.spatial_dim,
            .direction = traits.direction,
            .input_layout = traits.layout[0],
            .weight_layout = traits.layout[1],
            .output_layout = traits.layout[2],
            .data_type = traits.data_type,
            .input_element_op = traits.input_element_op,
            .weight_element_op = traits.weight_element_op,
            .output_element_op = traits.output_element_op,
        },
        conv::GemmAlgorithmInfo{
            .thread_block_size = traits.thread_block_size,
            .tile_dims = traits.tile_dims,
            .warp_gemm = traits.warp_gemm,
            .a_tile_transfer = traits.a_tile_transfer,
            .b_tile_transfer = traits.b_tile_transfer,
            .c_tile_transfer = traits.c_tile_transfer,
            .pipeline_version = traits.pipeline_version,
            .pipeline_scheduler = traits.pipeline_scheduler,
            .conv_specialization = traits.conv_specialization,
            .padding = traits.gemm_padding,
        },
        []<typename T = Instance>() { return reflect::instance_string<T>(); });
        traits, []<typename T = Instance>() { return reflect::instance_string<T>(); });
}

} // namespace ck_tile::reflect

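The simplified `describe()` above hands the whole traits object to the description together with a callable for the instance string. A minimal sketch of that shape follows. `SketchTraits`, `SketchDescription` and `sketch_describe` are hypothetical names, and the direction is a plain string here, whereas the real `ConvTraits` uses a `builder::ConvDirection` value.

```cpp
#include <functional>
#include <sstream>
#include <string>
#include <utility>

// Hypothetical, reduced stand-in for ConvTraits (the real struct has many
// more fields, and its direction is an enum rather than a string).
struct SketchTraits
{
    int spatial_dim;
    std::string direction;
};

// Reduced stand-in for ConvDescription: it stores the traits by value and a
// callable that produces the instance string only when it is requested.
class SketchDescription
{
    public:
    SketchDescription(SketchTraits traits, std::function<std::string()> instance_string_getter)
        : traits_(std::move(traits)), instance_string_getter_(std::move(instance_string_getter))
    {
    }

    std::string brief() const
    {
        std::ostringstream oss;
        oss << traits_.spatial_dim << "D " << traits_.direction << " convolution";
        return oss.str();
    }

    std::string instance_string() const { return instance_string_getter_(); }

    private:
    SketchTraits traits_;
    std::function<std::string()> instance_string_getter_;
};

// Mirrors the shape of describe<Instance>(): build the traits once, then pass
// them on together with a lambda that yields the instance string on demand.
inline SketchDescription sketch_describe()
{
    SketchTraits traits{.spatial_dim = 2, .direction = "forward"};
    return SketchDescription(traits, [] { return std::string("FakeKernel<256, 128>"); });
}
```

Storing one traits object instead of copying fields into separate signature and algorithm structs means a newly reflected parameter only has to be added in one place.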
@@ -29,44 +29,12 @@
#include <ck_tile/builder/reflect/description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>

namespace ck_tile::reflect {

namespace conv {

/// @brief Signature information for a convolution operation
/// Contains high-level properties that define the convolution's interface,
/// including dimensionality, data layout, data types, and elementwise operations.
struct ConvSignatureInfo
{
    int spatial_dim;
    builder::ConvDirection direction;
    builder::TensorLayout input_layout;
    builder::TensorLayout weight_layout;
    builder::TensorLayout output_layout;
    builder::DataType data_type;
    builder::ElementwiseOperation input_element_op;
    builder::ElementwiseOperation weight_element_op;
    builder::ElementwiseOperation output_element_op;
};

/// @brief Algorithm configuration for a convolution kernel
/// Contains low-level implementation details including thread block configuration,
/// tile dimensions, memory access patterns, and pipeline settings.
struct GemmAlgorithmInfo
{
    int thread_block_size;
    DataTileInfo tile_dims;
    WarpGemmParams warp_gemm;
    InputTileTransferInfo a_tile_transfer;
    InputTileTransferInfo b_tile_transfer;
    OutputTileTransferInfo c_tile_transfer;
    builder::PipelineVersion pipeline_version;
    builder::PipelineScheduler pipeline_scheduler;
    builder::ConvSpecialization conv_specialization;
    builder::GemmPadding padding;
};

/// @brief Provides human-readable descriptions of convolution kernel instances
/// Generates formatted text descriptions at various levels of detail for
/// understanding and documenting convolution kernel configurations.
@@ -74,16 +42,12 @@ class ConvDescription : public Description
{
    public:
    /// @brief Constructor for ConvDescription
    /// @param sig The signature information containing high-level convolution properties
    /// @param algo The algorithm configuration containing low-level implementation details
    /// @param traits The ConvTraits object containing all relevant signature and algorithm
    /// information
    /// @param instance_string_getter A callable that returns a string representation of the
    /// instance
    ConvDescription(ConvSignatureInfo sig,
                    GemmAlgorithmInfo algo,
                    std::function<std::string()> instance_string_getter)
        : signature_(std::move(sig)),
          algorithm_(std::move(algo)),
          instance_string_getter_(std::move(instance_string_getter))
    ConvDescription(ConvTraits traits, std::function<std::string()> instance_string_getter)
        : traits_(std::move(traits)), instance_string_getter_(std::move(instance_string_getter))
    {
    }

@@ -92,7 +56,7 @@ class ConvDescription : public Description
    std::string brief() const override
    {
        std::ostringstream oss;
        oss << signature_.spatial_dim << "D " << signature_.direction << " convolution";
        oss << traits_.spatial_dim << "D " << traits_.direction << " convolution";
        return oss.str();
    }

@@ -101,39 +65,42 @@ class ConvDescription : public Description
    std::string detailed() const override
    {
        TreeFormatter f;
        f.writeLine(0, signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel");
        f.writeLine(0, traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel");
        f.writeLine(1, "Signature");
        f.writeLine(2, "Tensor Type: ", signature_.data_type);
        f.writeLine(2, "Input Layout: ", signature_.input_layout);
        f.writeLine(2, "Weight Layout: ", signature_.weight_layout);
        f.writeLine(2, "Output Layout: ", signature_.output_layout);
        f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op);
        f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op);
        f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op);
        f.writeLine(2, "Tensor Type: ", traits_.data_type);
        f.writeLine(2, "Input Layout: ", traits_.layout[0]);
        f.writeLine(2, "Weight Layout: ", traits_.layout[1]);
        f.writeLine(2, "Output Layout: ", traits_.layout[2]);
        f.writeLine(2, "Input elementwise operation: ", traits_.input_element_op);
        f.writeLine(2, "Weights elementwise operation: ", traits_.weight_element_op);
        f.writeLast(2, "Output elementwise operation: ", traits_.output_element_op);

        f.writeLast(1, "Algorithm");
        // Compute Block section
        f.writeLine(2, "Thread block size: ", algorithm_.thread_block_size);
        f.writeLine(2, "Thread block size: ", traits_.thread_block_size);
        f.writeLine(2,
                    "Data tile size: ",
                    algorithm_.tile_dims.m,
                    traits_.tile_dims.m,
                    "×",
                    algorithm_.tile_dims.n,
                    traits_.tile_dims.n,
                    "×",
                    algorithm_.tile_dims.k);
        f.writeLine(2, "Gemm padding: ", algorithm_.padding);
        f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization);
                    traits_.tile_dims.k);
        if(traits_.gemm_padding)
            f.writeLine(
                2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
        else
            f.writeLine(2, "Struct does not contain optional gemm_padding argument");
        f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization);
        // Pipeline section
        f.writeLine(2, "Pipeline version: ", algorithm_.pipeline_version);
        f.writeLine(2, "Pipeline scheduler: ", algorithm_.pipeline_scheduler);
        f.writeLine(2, "Pipeline version: ", traits_.pipeline_version);
        f.writeLine(2, "Pipeline scheduler: ", traits_.pipeline_scheduler);
        f.writeLine(2, "Warp Gemm parameters: ");
        f.writeLine(
            3, "subtile size: ", algorithm_.warp_gemm.gemm_m, "×", algorithm_.warp_gemm.gemm_n);
        f.writeLine(3, "subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n);
        f.writeLast(3,
                    "Number of warp gemm iterations: ",
                    algorithm_.warp_gemm.m_iter,
                    traits_.warp_gemm.m_iter,
                    "×",
                    algorithm_.warp_gemm.n_iter);
                    traits_.warp_gemm.n_iter);

        // Memory Access section
        f.writeLast(2, "Memory access:");
@@ -141,99 +108,126 @@ class ConvDescription : public Description
        f.writeLine(3, "A Tile transfer: ");
        f.writeLine(4,
                    "Tile dimensions: ",
                    algorithm_.a_tile_transfer.tile_dimensions.k0,
                    traits_.a_tile_transfer.tile_dimensions.k0,
                    "×",
                    algorithm_.a_tile_transfer.tile_dimensions.m_or_n,
                    traits_.a_tile_transfer.tile_dimensions.m_or_n,
                    "×",
                    algorithm_.a_tile_transfer.tile_dimensions.k1,
                    traits_.a_tile_transfer.tile_dimensions.k1,
                    "×");
        f.writeLine(4,
                    "The innermost K subdimension size: ",
                    algorithm_.a_tile_transfer.transfer_params.k1);
        f.writeLine(
            4, "The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1);
        f.writeLine(4,
                    "Spatial thread distribution over the data tile: ",
                    algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[0],
                    traits_.a_tile_transfer.transfer_params.thread_cluster_order[0],
                    "×",
                    algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[1],
                    traits_.a_tile_transfer.transfer_params.thread_cluster_order[1],
                    "×",
                    algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
                    traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
        f.writeLine(4,
                    "The order of accessing data tile axes: ",
                    algorithm_.a_tile_transfer.transfer_params.src_access_order[0],
                    traits_.a_tile_transfer.transfer_params.src_access_order[0],
                    "×",
                    algorithm_.a_tile_transfer.transfer_params.src_access_order[1],
                    traits_.a_tile_transfer.transfer_params.src_access_order[1],
                    "×",
                    algorithm_.a_tile_transfer.transfer_params.src_access_order[2]);
                    traits_.a_tile_transfer.transfer_params.src_access_order[2]);
        f.writeLine(4,
                    "Vectorized memory access axis index (with contiguous memory): ",
                    algorithm_.a_tile_transfer.transfer_params.src_vector_dim);
                    traits_.a_tile_transfer.transfer_params.src_vector_dim);
        f.writeLine(4,
                    "Vector access (GMEM read) instruction size: ",
                    algorithm_.a_tile_transfer.transfer_params.src_scalar_per_vector);
                    traits_.a_tile_transfer.transfer_params.src_scalar_per_vector);
        f.writeLine(4,
                    "Vector access (LDS write) instruction size: ",
                    algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
                    traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
        f.writeLast(4,
                    "LDS data layout padding (to prevent bank conflicts): ",
                    algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
                    traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);

        f.writeLine(3, "B Tile transfer: ");
        f.writeLine(4,
                    "Tile dimensions: ",
                    algorithm_.b_tile_transfer.tile_dimensions.k0,
                    traits_.b_tile_transfer.tile_dimensions.k0,
                    "×",
                    algorithm_.b_tile_transfer.tile_dimensions.m_or_n,
                    traits_.b_tile_transfer.tile_dimensions.m_or_n,
                    "×",
                    algorithm_.b_tile_transfer.tile_dimensions.k1,
                    traits_.b_tile_transfer.tile_dimensions.k1,
                    "×");
        f.writeLine(4,
                    "The innermost K subdimension size: ",
                    algorithm_.b_tile_transfer.transfer_params.k1);
        f.writeLine(
            4, "The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1);
        f.writeLine(4,
                    "Spatial thread distribution over the data tile: ",
                    algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[0],
                    traits_.b_tile_transfer.transfer_params.thread_cluster_order[0],
                    "×",
                    algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[1],
                    traits_.b_tile_transfer.transfer_params.thread_cluster_order[1],
                    "×",
                    algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
                    traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
        f.writeLine(4,
                    "The order of accessing data tile axes: ",
                    algorithm_.b_tile_transfer.transfer_params.src_access_order[0],
                    traits_.b_tile_transfer.transfer_params.src_access_order[0],
                    "×",
                    algorithm_.b_tile_transfer.transfer_params.src_access_order[1],
                    traits_.b_tile_transfer.transfer_params.src_access_order[1],
                    "×",
                    algorithm_.b_tile_transfer.transfer_params.src_access_order[2]);
                    traits_.b_tile_transfer.transfer_params.src_access_order[2]);
        f.writeLine(4,
                    "Vectorized memory access axis index (with contiguous memory): ",
                    algorithm_.b_tile_transfer.transfer_params.src_vector_dim);
                    traits_.b_tile_transfer.transfer_params.src_vector_dim);
        f.writeLine(4,
                    "Vector access (GMEM read) instruction size: ",
                    algorithm_.b_tile_transfer.transfer_params.src_scalar_per_vector);
                    traits_.b_tile_transfer.transfer_params.src_scalar_per_vector);
        f.writeLine(4,
                    "Vector access (LDS write) instruction size: ",
                    algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
                    traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
        f.writeLast(4,
                    "LDS data layout padding (to prevent bank conflicts): ",
                    algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
                    traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);

        f.writeLast(3, "C Tile transfer: ");
        f.writeLine(4,
                    "Data shuffle (number of gemm instructions per iteration): ",
                    algorithm_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
                    traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
                    "×",
                    algorithm_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
                    traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
        f.writeLine(4,
                    "Spatial thread distribution used to store data: ",
                    algorithm_.c_tile_transfer.thread_cluster_dims[0],
                    traits_.c_tile_transfer.thread_cluster_dims[0],
                    "×",
                    algorithm_.c_tile_transfer.thread_cluster_dims[1],
                    traits_.c_tile_transfer.thread_cluster_dims[1],
                    "×",
                    algorithm_.c_tile_transfer.thread_cluster_dims[2],
                    traits_.c_tile_transfer.thread_cluster_dims[2],
                    "×",
                    algorithm_.c_tile_transfer.thread_cluster_dims[3]);
        f.writeLast(4,
                    traits_.c_tile_transfer.thread_cluster_dims[3]);
        f.writeLine(4,
                    "Vector access (GMEM write) instruction size: ",
                    algorithm_.c_tile_transfer.scalar_per_vector);
                    traits_.c_tile_transfer.scalar_per_vector);
        if(traits_.num_gemm_k_prefetch_stage)
            f.writeLine(
                2, "Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0));
        else
            f.writeLine(2,
                        "Struct does not contain optional "
                        "num_gemm_k_prefetch_stage parameter");

        if(traits_.max_transpose_transfer_src_scalar_per_vector)
            f.writeLine(2,
                        "Max Transpose transfer scr scalar per vector: ",
                        traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0));
        else
            f.writeLine(2,
                        "Struct does not contain optional "
                        "max_transpose_transfer_src_scalar_per_vector parameter");
        if(traits_.max_transpose_dst_scalar_per_vector)
            f.writeLine(2,
                        "Max Transpose dst scalar per vector: ",
                        traits_.max_transpose_dst_scalar_per_vector.value_or(0));
        else
            f.writeLine(
                2,
                "Struct does not contain optional max_transpose_dst_scalar_per_vector parameter");
        if(traits_.num_groups_to_merge)
            f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
        else
            f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter");

        return f.getString();
    }

@@ -242,8 +236,7 @@ class ConvDescription : public Description
    std::string instance_string() const override { return instance_string_getter_(); }

    private:
    ConvSignatureInfo signature_;
    GemmAlgorithmInfo algorithm_;
    ConvTraits traits_;
    std::function<std::string()> instance_string_getter_;
};

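The repeated `if`/`else` blocks that `detailed()` uses for optional parameters all follow one pattern: print the value when it is present, otherwise print a note that the kernel struct has no such parameter. A minimal sketch of that pattern as a standalone helper is shown below. `format_optional` is a hypothetical function that is not part of CK, and it returns a string rather than writing through `TreeFormatter`.

```cpp
#include <optional>
#include <sstream>
#include <string>
#include <string_view>

// Hypothetical helper (not part of CK): formats one optional parameter as
// "label: value" when it is present, and otherwise as a note that the kernel
// struct does not contain the parameter.
template <typename T>
std::string format_optional(std::string_view label,
                            std::string_view param_name,
                            const std::optional<T>& value)
{
    std::ostringstream oss;
    if(value.has_value())
        oss << label << ": " << *value;
    else
        oss << "Struct does not contain optional " << param_name << " parameter";
    return oss.str();
}
```

Branching on `has_value()` keeps a legitimate value of `0` distinguishable from a missing parameter, which a bare `value_or(0)` would not.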
@@ -88,7 +88,7 @@ struct ConvTraits
    builder::ElementwiseOperation weight_element_op;
    builder::ElementwiseOperation output_element_op;

    builder::GemmPadding gemm_padding;
    std::optional<builder::GemmPadding> gemm_padding = std::nullopt;
    builder::ConvSpecialization conv_specialization;

    // --- Algorithm Information ---
@@ -102,8 +102,14 @@ struct ConvTraits

    OutputTileTransferInfo c_tile_transfer;

    std::optional<int> num_gemm_k_prefetch_stage = std::nullopt;

    builder::PipelineVersion pipeline_version;
    builder::PipelineScheduler pipeline_scheduler;

    std::optional<int> max_transpose_transfer_src_scalar_per_vector = std::nullopt;
    std::optional<int> max_transpose_dst_scalar_per_vector = std::nullopt;
    std::optional<int> num_groups_to_merge = std::nullopt;
};

} // namespace ck_tile::reflect::conv

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag
template <typename Instance>
    requires HasInstanceTraits<Instance> &&
             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                          DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;

    return ConvTraits{
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        .layout = bwd_wei_conv_layout<Instance>(),
        .data_type = conv_data_type<typename InstTraits::InDataType>(),
        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
        .a_tile_transfer =
            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
        .b_tile_transfer =
            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
    };
}

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag
template <typename Instance>
    requires HasInstanceTraits<Instance> &&
             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                          DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;

    return ConvTraits{
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        .layout = bwd_wei_conv_layout<Instance>(),
        .data_type = conv_data_type<typename InstTraits::InDataType>(),
        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
        .a_tile_transfer =
            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
        .b_tile_transfer =
            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
        .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
        .c_tile_transfer =
            {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
             .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                     InstTraits::kCThreadClusterLengths[1],
                                     InstTraits::kCThreadClusterLengths[2],
                                     InstTraits::kCThreadClusterLengths[3]},
             .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
    };
}

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag
template <typename Instance>
    requires HasInstanceTraits<Instance> &&
             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                          DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;

    return ConvTraits{
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        .layout = bwd_wei_conv_layout<Instance>(),
        .data_type = conv_data_type<typename InstTraits::InDataType>(),
        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
        .a_tile_transfer =
            conv_traits_a_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
        .b_tile_transfer =
            conv_traits_b_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
        .max_transpose_transfer_src_scalar_per_vector =
            InstTraits::kTransposeTransferSrcScalarPerVector,
        .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
        .num_groups_to_merge = InstTraits::kNumGroupsToMerge,
    };
}

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdTwoStage_Xdl_CShuffle_Tag
|
||||
+template <typename Instance>
+    requires HasInstanceTraits<Instance> &&
+             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
+                          DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag>
+constexpr ConvTraits instance_to_conv_traits()
+{
+    using InstTraits = InstanceTraits<Instance>;
+
+    return ConvTraits{
+        .spatial_dim = InstTraits::kSpatialDim,
+        .direction = conv_direction<Instance>(),
+        .layout = bwd_wei_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::InDataType>(),
+        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
+        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
+        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
+        .conv_specialization = conv_spec<Instance>(),
+        .thread_block_size = InstTraits::kBlockSize,
+        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
+        .a_tile_transfer =
+            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
+        .b_tile_transfer =
+            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
+        .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
+        .c_tile_transfer =
+            {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
+                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
+             .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
+                                     InstTraits::kCThreadClusterLengths[1],
+                                     InstTraits::kCThreadClusterLengths[2],
+                                     InstTraits::kCThreadClusterLengths[3]},
+             .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
+        .pipeline_version = get_pipeline_version<InstTraits>(),
+        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
+        .max_transpose_transfer_src_scalar_per_vector =
+            InstTraits::kTransposeTransferSrcScalarPerVector,
+        .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
+        .num_groups_to_merge = InstTraits::kNumGroupsToMerge,
+    };
+}
+
+} // namespace ck_tile::reflect::conv
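The overloads of `instance_to_conv_traits()` in this commit are selected by the `device_kernel_tag` that each `InstanceTraits` specialization exposes. The sketch below illustrates that selection mechanism in isolation. It is not the library's code: `XdlTag`, `WmmaTag`, `MyXdlKernel`, `MyWmmaKernel`, `MiniTraits`, `has_tag`, and `instance_to_mini_traits` are hypothetical names, and the constraint is written with `std::enable_if_t` so the example also compiles as C++17, whereas the diff uses a C++20 `requires` clause with `std::same_as`.

```cpp
#include <type_traits>

// Hypothetical tag types; stand-ins for tags such as
// DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag in the diff.
struct XdlTag {};
struct WmmaTag {};

// Hypothetical kernel instance types.
struct MyXdlKernel {};
struct MyWmmaKernel {};

// Primary template is declared only; each kernel supplies a specialization.
template <typename Instance>
struct InstanceTraits;

template <>
struct InstanceTraits<MyXdlKernel>
{
    using device_kernel_tag         = XdlTag;
    static constexpr int kBlockSize = 256;
};

template <>
struct InstanceTraits<MyWmmaKernel>
{
    using device_kernel_tag         = WmmaTag;
    static constexpr int kBlockSize = 128;
};

// Simplified result type (an assumption; the real ConvTraits has many more fields).
struct MiniTraits
{
    int thread_block_size;
    bool uses_wmma;
};

// True when the tag exposed by InstanceTraits<Instance> is exactly Tag.
template <typename Instance, typename Tag>
inline constexpr bool has_tag =
    std::is_same_v<typename InstanceTraits<Instance>::device_kernel_tag, Tag>;

// Overload that participates only for XDL-tagged instances.
template <typename Instance, std::enable_if_t<has_tag<Instance, XdlTag>, int> = 0>
constexpr MiniTraits instance_to_mini_traits()
{
    return MiniTraits{InstanceTraits<Instance>::kBlockSize, false};
}

// Overload that participates only for WMMA-tagged instances.
template <typename Instance, std::enable_if_t<has_tag<Instance, WmmaTag>, int> = 0>
constexpr MiniTraits instance_to_mini_traits()
{
    return MiniTraits{InstanceTraits<Instance>::kBlockSize, true};
}
```

Because exactly one overload is viable per tag, adding a new kernel family only needs a new tag and one more overload; existing overloads stay untouched.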
@@ -0,0 +1,48 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <concepts>
+
+#include "ck_tile/builder/reflect/conv_traits.hpp"
+#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
+#include "ck_tile/builder/reflect/instance_traits.hpp"
+#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
+
+namespace ck_tile::reflect::conv {
+
+/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
+template <typename Instance>
+    requires HasInstanceTraits<Instance> &&
+             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
+                          DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag>
+constexpr ConvTraits instance_to_conv_traits()
+{
+    using InstTraits = InstanceTraits<Instance>;
+
+    return ConvTraits{
+        .spatial_dim = InstTraits::kSpatialDim,
+        .direction = conv_direction<Instance>(),
+        .layout = bwd_wei_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::InDataType>(),
+        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
+        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
+        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
+        .conv_specialization = conv_spec<Instance>(),
+        .thread_block_size = InstTraits::kBlockSize,
+        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
+        .a_tile_transfer =
+            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
+        .b_tile_transfer =
+            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
+        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
+        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
+        .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
+        .pipeline_version = get_pipeline_version<InstTraits>(),
+        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
+
+    };
+}
+
+} // namespace ck_tile::reflect::conv
@@ -0,0 +1,50 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <concepts>
+
+#include "ck_tile/builder/reflect/conv_traits.hpp"
+#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
+#include "ck_tile/builder/reflect/instance_traits.hpp"
+#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
+
+namespace ck_tile::reflect::conv {
+
+/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag
+template <typename Instance>
+    requires HasInstanceTraits<Instance> &&
+             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
+                          DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag>
+constexpr ConvTraits instance_to_conv_traits()
+{
+    using InstTraits = InstanceTraits<Instance>;
+
+    return ConvTraits{
+        .spatial_dim = InstTraits::kSpatialDim,
+        .direction = conv_direction<Instance>(),
+        .layout = bwd_wei_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::InDataType>(),
+        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
+        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
+        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
+        .conv_specialization = conv_spec<Instance>(),
+        .thread_block_size = InstTraits::kBlockSize,
+        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
+        .a_tile_transfer =
+            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
+        .b_tile_transfer =
+            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
+        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
+        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
+        .pipeline_version = get_pipeline_version<InstTraits>(),
+        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
+        .max_transpose_transfer_src_scalar_per_vector =
+            InstTraits::kMaxTransposeTransferSrcScalarPerVector,
+        .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
+
+    };
+}
+
+} // namespace ck_tile::reflect::conv
@@ -0,0 +1,56 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <concepts>
+
+#include "ck_tile/builder/reflect/conv_traits.hpp"
+#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
+#include "ck_tile/builder/reflect/instance_traits.hpp"
+#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
+
+namespace ck_tile::reflect::conv {
+
+/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
+template <typename Instance>
+    requires HasInstanceTraits<Instance> &&
+             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
+                          DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag>
+constexpr ConvTraits instance_to_conv_traits()
+{
+    using InstTraits = InstanceTraits<Instance>;
+
+    return ConvTraits{
+        .spatial_dim = InstTraits::kSpatialDim,
+        .direction = conv_direction<Instance>(),
+        .layout = bwd_wei_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::InDataType>(),
+        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
+        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
+        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
+        .conv_specialization = conv_spec<Instance>(),
+        .thread_block_size = InstTraits::kBlockSize,
+        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
+        .a_tile_transfer =
+            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
+        .b_tile_transfer =
+            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
+        .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
+        .c_tile_transfer =
+            {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
+                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
+             .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
+                                     InstTraits::kCThreadClusterLengths[1],
+                                     InstTraits::kCThreadClusterLengths[2],
+                                     InstTraits::kCThreadClusterLengths[3]},
+             .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
+        .pipeline_version = get_pipeline_version<InstTraits>(),
+        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
+        .max_transpose_transfer_src_scalar_per_vector =
+            InstTraits::kMaxTransposeTransferSrcScalarPerVector,
+        .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
+    };
+}
+
+} // namespace ck_tile::reflect::conv
@@ -0,0 +1,53 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <concepts>
+
+#include "ck_tile/builder/reflect/conv_traits.hpp"
+#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
+#include "ck_tile/builder/reflect/instance_traits.hpp"
+#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
+
+namespace ck_tile::reflect::conv {
+
+/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
+template <typename Instance>
+    requires HasInstanceTraits<Instance> &&
+             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
+                          DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag>
+constexpr ConvTraits instance_to_conv_traits()
+{
+    using InstTraits = InstanceTraits<Instance>;
+
+    return ConvTraits{
+        .spatial_dim = InstTraits::kSpatialDim,
+        .direction = conv_direction<Instance>(),
+        .layout = bwd_wei_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::InDataType>(),
+        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
+        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
+        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
+        .conv_specialization = conv_spec<Instance>(),
+        .thread_block_size = InstTraits::kBlockSize,
+        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
+        .a_tile_transfer =
+            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
+        .b_tile_transfer =
+            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
+        .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
+        .c_tile_transfer =
+            {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
+                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
+             .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
+                                     InstTraits::kCThreadClusterLengths[1],
+                                     InstTraits::kCThreadClusterLengths[2],
+                                     InstTraits::kCThreadClusterLengths[3]},
+             .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
+        .pipeline_version = get_pipeline_version<InstTraits>(),
+        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
+    };
+}
+
+} // namespace ck_tile::reflect::conv
@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
     return ConvTraits{
         .spatial_dim = InstTraits::kSpatialDim,
         .direction = conv_direction<Instance>(),
-        .layout = conv_layout<Instance>(),
-        .data_type = conv_data_type<Instance>(),
+        .layout = fwd_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::ADataType>(),
         .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
         .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
         .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
     return ConvTraits{
         .spatial_dim = InstTraits::kSpatialDim,
         .direction = conv_direction<Instance>(),
-        .layout = conv_layout<Instance>(),
-        .data_type = conv_data_type<Instance>(),
+        .layout = fwd_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::ADataType>(),
         .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
         .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
         .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
@@ -0,0 +1,46 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <concepts>
+
+#include "ck_tile/builder/reflect/conv_traits.hpp"
+#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
+#include "ck_tile/builder/reflect/instance_traits.hpp"
+#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
+
+namespace ck_tile::reflect::conv {
+
+/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag
+template <typename Instance>
+    requires HasInstanceTraits<Instance> &&
+             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
+                          DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag>
+constexpr ConvTraits instance_to_conv_traits()
+{
+    using InstTraits = InstanceTraits<Instance>;
+
+    return ConvTraits{
+        .spatial_dim = InstTraits::kSpatialDim,
+        .direction = conv_direction<Instance>(),
+        .layout = fwd_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::ADataType>(),
+        .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
+        .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
+        .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
+        .gemm_padding = gemm_spec<Instance>(),
+        .conv_specialization = conv_spec<Instance>(),
+        .thread_block_size = InstTraits::kBlockSize,
+        .tile_dims = conv_traits_data_tile<InstTraits>(),
+        .a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
+        .b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
+        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
+        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
+        .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
+        .pipeline_version = get_pipeline_version<InstTraits>(),
+        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
+    };
+}
+
+} // namespace ck_tile::reflect::conv
@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
     return ConvTraits{
         .spatial_dim = InstTraits::kSpatialDim,
         .direction = conv_direction<Instance>(),
-        .layout = conv_layout<Instance>(),
-        .data_type = conv_data_type<Instance>(),
+        .layout = fwd_conv_layout<Instance>(),
+        .data_type = conv_data_type<typename InstTraits::ADataType>(),
         .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
         .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
         .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
@@ -80,6 +80,22 @@ namespace ck_tile::reflect::conv {
 // SECTION 1: ENUM CONVERSIONS
 // ============================================================================

+// Forward convolution layout concept - checks for A/B/E layout types
+template <typename T>
+concept HasFwdConvLayouts = requires {
+    typename T::ALayout;
+    typename T::BLayout;
+    typename T::ELayout;
+};
+
+// Backward weight layout concept - checks for In/Wei/Out layout types
+template <typename T>
+concept HasBwdWeiLayouts = requires {
+    typename T::InLayout;
+    typename T::WeiLayout;
+    typename T::OutLayout;
+};
+
 /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
 /// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
 /// @return The corresponding builder::PipelineVersion enum value.
@@ -322,12 +338,25 @@ constexpr builder::ConvSpecialization conv_spec()
 // Tensor Layouts
 // ----------------------------------------------------------------------------

+// Helper variable template to check whether the CK layout types match the expected ones
+template <typename A,
+          typename B,
+          typename E,
+          typename ExpectedA,
+          typename ExpectedB,
+          typename ExpectedE>
+inline constexpr bool layouts_are =
+    std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
+
 /// @brief Helper function to report unsupported layout combinations with a clear error message.
 /// @details This consteval function uses throw (not static_assert) to ensure the error is not
 /// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
+/// @details This consteval function is designed to fail at compile time with a descriptive
+/// error message when an unsupported layout combination is encountered.
 template <typename A, typename B, typename E, int SpatialDim>
 [[noreturn]] consteval void report_unsupported_layout_error()
 {
+    // This will produce a compile-time error with the exception message
     throw "Unsupported convolution layout combination detected!\n"
           "The combination of ALayout, BLayout, and ELayout template parameters\n"
           "is not recognized for the given spatial dimension.\n"
@@ -335,111 +364,99 @@ template <typename A, typename B, typename E, int SpatialDim>
           "Check the conv_layout() function for the list of supported layout combinations.";
 }

-/// @brief Derives the grouped convolution layout from a device kernel Instance type.
-/// @tparam Instance The device kernel instance type.
-/// @return An std::array<builder::TensorLayout, 3> containing the layouts for:
-/// - [0] Input tensor layout
-/// - [1] Weight tensor layout
-/// - [2] Output tensor layout
-/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
-/// along with the spatial dimension to determine the appropriate layout configuration.
-///
-/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
-/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
-///
-/// @note Compilation will fail with a clear error message if the layout combination
-/// is not supported for the given spatial dimension.
-///
-/// TODO: If we don't check for supported layouts, this function can be simplified.
-template <typename Instance>
-constexpr std::array<builder::TensorLayout, 3> conv_layout()
+template <typename A, typename B, typename E, int kSpatialDim>
+constexpr auto conv_layout()
 {
-    using InstTraits = InstanceTraits<Instance>;
-    using A = typename InstTraits::ALayout;
-    using B = typename InstTraits::BLayout;
-    using E = typename InstTraits::ELayout;
-    namespace ctl = ck::tensor_layout::convolution;

+    // Helper lambda to construct layout array
+    auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
+
+    namespace ctl = ck::tensor_layout::convolution;
     using enum builder::TensorLayout;

-    // Helper to check if layouts match expected types
-    constexpr auto layouts_match = []<typename ExpA, typename ExpB, typename ExpE>() {
-        return std::is_same_v<A, ExpA> && std::is_same_v<B, ExpB> && std::is_same_v<E, ExpE>;
-    };
+    switch(kSpatialDim)
+    {
+    case 1:
+        if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
+            return layouts(GNWC, GKXC, GNWK);
+        if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
+            return layouts(GNWC, GKXC, GNWK);
+        if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
+            return layouts(NWGC, GKXC, NWGK);
+        if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
+            return layouts(NGCW, GKXC, NGKW);
+        if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
+            return layouts(NGCW, GKCX, NGKW);
+        break;
+    case 2:
+        if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
+            return layouts(GNHWC, GKYXC, GNHWK);
+        if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
+            return layouts(GNHWC, GKYXC, GNHWK);
+        if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
+            return layouts(NHWGC, GKYXC, NHWGK);
+        if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
+            return layouts(NHWGC, GKYXC, NHWGK);
+        if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
+            return layouts(NGCHW, GKYXC, NGKHW);
+        if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
+            return layouts(NGCHW, GKCYX, NGKHW);
+        break;
+    case 3:
+        if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
+            return layouts(GNDHWC, GKZYXC, GNDHWK);
+        if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
+            return layouts(GNDHWC, GKZYXC, GNDHWK);
+        if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
+            return layouts(NDHWGC, GKZYXC, NDHWGK);
+        if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
+            return layouts(NGCDHW, GKZYXC, NGKDHW);
+        if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
+            return layouts(NGCDHW, GKCZYX, NGKDHW);
+        break;
+    }

-    // Helper to construct layout array
-    constexpr auto make_layouts = [](auto in, auto weight, auto out) {
-        return std::array<builder::TensorLayout, 3>{in, weight, out};
-    };
+    // If we reach here, the layout combination is not supported
+    // Call consteval function to trigger a compile-time error with a clear message
+    report_unsupported_layout_error<A, B, E, kSpatialDim>();

-    constexpr int spatial_dim = InstTraits::kSpatialDim;
+    // This return is unreachable but needed to satisfy the compiler
+    return layouts(GNHWC, GKYXC, GNHWK);
+}

-    if constexpr(spatial_dim == 1)
-    {
-        if constexpr(layouts_match.template operator()<ctl::GNWC, ctl::GKXC, ctl::GNWK>())
-            return make_layouts(GNWC, GKXC, GNWK);
-        else if constexpr(layouts_match
-                              .template operator()<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>())
-            return make_layouts(GNWC, GKXC, GNWK);
-        else if constexpr(layouts_match.template operator()<ctl::NWGC, ctl::GKXC, ctl::NWGK>())
-            return make_layouts(NWGC, GKXC, NWGK);
-        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKXC, ctl::NGKW>())
-            return make_layouts(NGCW, GKXC, NGKW);
-        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKCX, ctl::NGKW>())
-            return make_layouts(NGCW, GKCX, NGKW);
-        else
-        {
-            report_unsupported_layout_error<A, B, E, spatial_dim>();
-            return make_layouts(GNWC, GKXC, GNWK); // Unreachable
-        }
-    }
-    else if constexpr(spatial_dim == 2)
-    {
-        if constexpr(layouts_match.template operator()<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>())
-            return make_layouts(GNHWC, GKYXC, GNHWK);
-        else if constexpr(layouts_match
-                              .template operator()<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>())
-            return make_layouts(GNHWC, GKYXC, GNHWK);
-        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>())
-            return make_layouts(NHWGC, GKYXC, NHWGK);
-        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>())
-            return make_layouts(NHWGC, GKYXC, NHWGK);
-        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>())
-            return make_layouts(NGCHW, GKYXC, NGKHW);
-        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>())
-            return make_layouts(NGCHW, GKCYX, NGKHW);
-        else
-        {
-            report_unsupported_layout_error<A, B, E, spatial_dim>();
-            return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
-        }
-    }
-    else if constexpr(spatial_dim == 3)
-    {
-        if constexpr(layouts_match.template operator()<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>())
-            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
-        else if constexpr(layouts_match
-                              .template operator()<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>())
-            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
-        else if constexpr(layouts_match
-                              .template operator()<ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>())
-            return make_layouts(NDHWGC, GKZYXC, NDHWGK);
-        else if constexpr(layouts_match
-                              .template operator()<ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>())
-            return make_layouts(NGCDHW, GKZYXC, NGKDHW);
-        else if constexpr(layouts_match
-                              .template operator()<ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>())
-            return make_layouts(NGCDHW, GKCZYX, NGKDHW);
-        else
-        {
-            report_unsupported_layout_error<A, B, E, spatial_dim>();
-            return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
-        }
-    }
-    else
-    {
-        report_unsupported_layout_error<A, B, E, spatial_dim>();
-        return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
-    }
+/// @brief Derives the grouped convolution layout from a forward device kernel `Instance` type.
+/// @tparam Instance The device kernel instance type (exposes ALayout, BLayout, ELayout).
+/// @return An std::array corresponding to the tensor layouts:
+/// index 0 -> Input layout
+/// index 1 -> Weight layout
+/// index 2 -> Output layout
+
+template <typename Instance>
+constexpr auto fwd_conv_layout()
+    requires HasFwdConvLayouts<InstanceTraits<Instance>>
+{
+
+    using A = typename InstanceTraits<Instance>::ALayout;
+    using B = typename InstanceTraits<Instance>::BLayout;
+    using E = typename InstanceTraits<Instance>::ELayout;
+    return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
+}
+
+/// @brief Derives the grouped convolution layout from a backward-weight device kernel
+/// `Instance` type.
+/// @tparam Instance The device kernel instance type (exposes InLayout, WeiLayout, OutLayout).
+/// @return An std::array corresponding to the tensor layouts:
+/// index 0 -> Input layout
+/// index 1 -> Weight layout
+/// index 2 -> Output layout
+template <typename Instance>
+constexpr auto bwd_wei_conv_layout()
+    requires HasBwdWeiLayouts<InstanceTraits<Instance>>
+{
+
+    using A = typename InstanceTraits<Instance>::InLayout;
+    using B = typename InstanceTraits<Instance>::WeiLayout;
+    using E = typename InstanceTraits<Instance>::OutLayout;
+    return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
 }

 // ----------------------------------------------------------------------------
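The hunk above splits layout lookup into one shared function over three layout types plus two thin wrappers, one per member-naming convention (`ALayout`/`BLayout`/`ELayout` for forward kernels, `InLayout`/`WeiLayout`/`OutLayout` for backward-weight kernels). The following is a minimal sketch of that structure, not the library's implementation. All names in it (`LayoutNHWGC`, `layout_ids`, `layout_ids_of`, `FwdTraits`, `BwdWeiTraits`, and the detection structs) are hypothetical, and member detection uses `std::void_t` so the example also compiles as C++17, where the diff uses C++20 concepts.

```cpp
#include <array>
#include <type_traits>

// Hypothetical layout tag types carrying a numeric id for the demo.
struct LayoutNHWGC { static constexpr int id = 1; };
struct LayoutGKYXC { static constexpr int id = 2; };
struct LayoutNHWGK { static constexpr int id = 3; };

// Detects the forward naming convention (ALayout/BLayout/ELayout).
template <typename T, typename = void>
struct has_fwd_layouts : std::false_type {};
template <typename T>
struct has_fwd_layouts<
    T,
    std::void_t<typename T::ALayout, typename T::BLayout, typename T::ELayout>>
    : std::true_type {};

// Detects the backward-weight naming convention (InLayout/WeiLayout/OutLayout).
template <typename T, typename = void>
struct has_bwd_wei_layouts : std::false_type {};
template <typename T>
struct has_bwd_wei_layouts<
    T,
    std::void_t<typename T::InLayout, typename T::WeiLayout, typename T::OutLayout>>
    : std::true_type {};

// Shared implementation: depends only on the three layout types.
template <typename In, typename Wei, typename Out>
constexpr std::array<int, 3> layout_ids()
{
    return {In::id, Wei::id, Out::id};
}

// Thin front end: picks the member names, then forwards to the shared code.
template <typename Traits>
constexpr std::array<int, 3> layout_ids_of()
{
    if constexpr(has_fwd_layouts<Traits>::value)
    {
        return layout_ids<typename Traits::ALayout,
                          typename Traits::BLayout,
                          typename Traits::ELayout>();
    }
    else
    {
        static_assert(has_bwd_wei_layouts<Traits>::value,
                      "Traits must expose A/B/E or In/Wei/Out layout types");
        return layout_ids<typename Traits::InLayout,
                          typename Traits::WeiLayout,
                          typename Traits::OutLayout>();
    }
}

// Hypothetical traits types, one per naming convention.
struct FwdTraits
{
    using ALayout = LayoutNHWGC;
    using BLayout = LayoutGKYXC;
    using ELayout = LayoutNHWGK;
};
struct BwdWeiTraits
{
    using InLayout  = LayoutNHWGC;
    using WeiLayout = LayoutGKYXC;
    using OutLayout = LayoutNHWGK;
};
```

Keeping the lookup table in one place means a newly supported layout combination is added once and becomes available to both kernel directions.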
@@ -447,13 +464,11 @@ constexpr std::array<builder::TensorLayout, 3> conv_layout()
 // ----------------------------------------------------------------------------

 /// @brief Helper function to report unsupported data type with a clear error message.
-/// @details This consteval function uses throw (not static_assert) to ensure the error is not
-/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
-template <typename ADataType>
+template <typename DataTypeFromInstance>
 [[noreturn]] consteval void report_unsupported_data_type_error()
 {
     throw "Unsupported data type detected!\n"
-          "The ADataType is not recognized.\n"
+          "The DataTypeFromInstance is not recognized.\n"
           "Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
           "ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
           "ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
@@ -462,62 +477,44 @@ template <typename ADataType>
|
||||
"Please verify that your kernel instance uses a supported data type.";
|
||||
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::DataType enum value representing the input data type.
|
||||
/// @details This function examines the Instance's ADataType to determine the data type
|
||||
/// used for the input tensor. The function supports various floating-point and integer
|
||||
/// types, including tuple types for mixed-precision operations.
|
||||
///
|
||||
/// Supported data types include:
|
||||
/// - FP16 (ck::half_t)
|
||||
/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
|
||||
/// - BF16 (ck::bhalf_t)
|
||||
/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
|
||||
/// - FP32 (float)
|
||||
/// - FP32_FP32 (ck::Tuple<float, float>)
|
||||
/// - FP64 (double)
|
||||
/// - FP8 (ck::f8_t)
|
||||
/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
|
||||
/// - I8 (int8_t)
|
||||
/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
|
||||
/// - U8 (uint8_t)
|
||||
template <typename Instance>
|
||||
/// @brief Derives the data type from a device kernel `Instance` type.
|
||||
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
|
||||
// Note: maybe move to types.hpp?
|
||||
template <typename DataTypeFromInstance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
if constexpr(std::is_same_v<DataTypeFromInstance, ck::half_t>)
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
return FP16_FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bhalf_t>)
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
return BF16_BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, float>)
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<float, float>>)
|
||||
return FP32_FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, double>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, double>)
|
||||
return FP64;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::f8_t>)
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bf8_fnuz_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bf8_ocp_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, int8_t>)
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<int8_t, int8_t>>)
|
||||
return I8_I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
else if constexpr(std::is_same_v<DataTypeFromInstance, uint8_t>)
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
report_unsupported_data_type_error<ADataType>();
|
||||
report_unsupported_data_type_error<DataTypeFromInstance>();
|
||||
return FP32; // Unreachable
|
||||
}
|
||||
}
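The chain above maps an element type to an enum value at compile time with `if constexpr`. A minimal, self-contained sketch of the same pattern follows. The `DataType` enum, the `to_data_type` name, and the use of `std::tuple` in place of `ck::Tuple` are assumptions made for illustration and are not part of the builder API; the real code reports an error in the final branch instead of returning a sentinel.

```cpp
#include <cstdint>
#include <tuple>
#include <type_traits>

// Hypothetical stand-in for the builder's data type enum.
enum class DataType { FP32, FP32_FP32, FP64, I8, U8, Unsupported };

// Maps a C++ type to an enum value. Only the branch whose condition holds
// is instantiated, so unsupported types never touch the other branches.
template <typename T>
constexpr DataType to_data_type()
{
    if constexpr(std::is_same_v<T, float>)
        return DataType::FP32;
    else if constexpr(std::is_same_v<T, std::tuple<float, float>>)
        return DataType::FP32_FP32; // std::tuple stands in for ck::Tuple here
    else if constexpr(std::is_same_v<T, double>)
        return DataType::FP64;
    else if constexpr(std::is_same_v<T, std::int8_t>)
        return DataType::I8;
    else if constexpr(std::is_same_v<T, std::uint8_t>)
        return DataType::U8;
    else
        return DataType::Unsupported; // sentinel used only in this sketch
}

// Compile-time usage checks.
static_assert(to_data_type<float>() == DataType::FP32);
static_assert(to_data_type<std::tuple<float, float>>() == DataType::FP32_FP32);
static_assert(to_data_type<char>() == DataType::Unsupported);
```

Because every branch is resolved at compile time, a mismatch between a kernel's element type and the supported set shows up as a build-time result rather than a runtime failure.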
|
||||
@@ -736,4 +733,92 @@ constexpr builder::PipelineScheduler get_pipeline_scheduler()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
// SECTION 4: Helper functions for common structures often used in reflection
// ============================================================================

template <typename InstTraits>
constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock)
{
    return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0};
}

template <typename InstTraits>
constexpr InputTileTransferInfo
conv_traits_a_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock)
{
    return InputTileTransferInfo{
        .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, .k1 = _k1},
        .transfer_params = {.k1                    = _k1,
                            .thread_cluster_dims   = InstTraits::kAThreadClusterLengths,
                            .thread_cluster_order  = InstTraits::kAThreadClusterArrangeOrder,
                            .src_access_order      = InstTraits::kABlockTransferSrcAccessOrder,
                            .src_vector_dim        = InstTraits::kABlockTransferSrcVectorDim,
                            .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
                            .dst_scalar_per_vector_k1 =
                                InstTraits::kABlockTransferDstScalarPerVectorK1,
                            .lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
}

template <typename InstTraits>
constexpr InputTileTransferInfo
conv_traits_b_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock)
{
    return InputTileTransferInfo{
        .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, .k1 = _k1},
        .transfer_params = {.k1                    = _k1,
                            .thread_cluster_dims   = InstTraits::kBThreadClusterLengths,
                            .thread_cluster_order  = InstTraits::kBThreadClusterArrangeOrder,
                            .src_access_order      = InstTraits::kBBlockTransferSrcAccessOrder,
                            .src_vector_dim        = InstTraits::kBBlockTransferSrcVectorDim,
                            .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
                            .dst_scalar_per_vector_k1 =
                                InstTraits::kBBlockTransferDstScalarPerVectorK1,
                            .lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
}

template <typename InstTraits>
constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
{
    return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma,
                          .gemm_n = InstTraits::kNPerWmma,
                          .m_iter = InstTraits::kMRepeat,
                          .n_iter = InstTraits::kNRepeat};
}

template <typename InstTraits>
constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
{
    return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL,
                          .gemm_n = InstTraits::kNPerXDL,
                          .m_iter = InstTraits::kMXdlPerWave,
                          .n_iter = InstTraits::kNXdlPerWave};
}

template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
{
    return OutputTileTransferInfo{
        .shuffle_params      = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
                                .n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
        .thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
                                InstTraits::kCDEThreadClusterLengths[1],
                                InstTraits::kCDEThreadClusterLengths[2],
                                InstTraits::kCDEThreadClusterLengths[3]},
        .scalar_per_vector   = InstTraits::kCDEBlockTransferScalarPerVector};
}

template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer()
{
    return OutputTileTransferInfo{
        .shuffle_params      = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
        .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                InstTraits::kCThreadClusterLengths[1],
                                InstTraits::kCThreadClusterLengths[2],
                                InstTraits::kCThreadClusterLengths[3]},
        .scalar_per_vector   = InstTraits::kCBlockTransferScalarPerVector};
}

} // namespace ck_tile::reflect::conv
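The helpers above all follow one pattern: a function template reads `static constexpr` members from an instance-traits type and packs them into a plain struct with designated initializers. The sketch below reproduces that pattern in isolation. `TileDims`, `TransferParams`, `MockInstTraits`, and `a_transfer_params` are simplified, hypothetical stand-ins (the real `InputTileTransferInfo` has more fields), and the example assumes a C++20 compiler.

```cpp
#include <array>

// Hypothetical, trimmed-down mirrors of the reflected structs.
struct TileDims
{
    int k0;
    int m_or_n;
    int k1;
};

struct TransferParams
{
    int k1;
    std::array<int, 3> thread_cluster_dims;
    bool lds_padding;
};

struct InputTileTransferInfo
{
    TileDims tile_dimensions;
    TransferParams transfer_params;
};

// Hypothetical instance traits exposing the member names the helper expects.
struct MockInstTraits
{
    static constexpr int kMPerBlock = 128;
    static constexpr int kKPerBlock = 64;
    static constexpr std::array<int, 3> kAThreadClusterLengths{4, 64, 1};
    static constexpr bool kABlockLdsExtraM = true;
};

// Same shape as the A-side helper: k0 is derived as k_per_block / k1, and
// k_per_block defaults to the traits' kKPerBlock unless the caller overrides it.
template <typename InstTraits>
constexpr InputTileTransferInfo a_transfer_params(int k1,
                                                  int k_per_block = InstTraits::kKPerBlock)
{
    return InputTileTransferInfo{
        .tile_dimensions = {.k0 = k_per_block / k1, .m_or_n = InstTraits::kMPerBlock, .k1 = k1},
        .transfer_params = {.k1                  = k1,
                            .thread_cluster_dims = InstTraits::kAThreadClusterLengths,
                            .lds_padding         = InstTraits::kABlockLdsExtraM}};
}

constexpr auto info = a_transfer_params<MockInstTraits>(8);
static_assert(info.tile_dimensions.k0 == 8); // 64 / 8
static_assert(info.transfer_params.thread_cluster_dims[1] == 64);
```

Taking the K extent as a defaulted argument lets a caller whose traits do not describe K as `kKPerBlock` pass its own value while still reusing the helper.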
|
||||
|
||||
@@ -3,6 +3,18 @@
|
||||
|
||||
#pragma once

// Fwd instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"

// Bwd weight instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
|
||||
@@ -62,6 +62,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3;
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 device kernel
|
||||
struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag
|
||||
{
|
||||
};
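An empty tag struct like the one above carries no data; it exists so that generic code can tell kernel families apart through the `device_kernel_tag` alias on the traits type. The following sketch shows one plausible way such a tag could be consumed. The tag names, the mock traits, and `warp_gemm_kind` are hypothetical and only illustrate the dispatch idea, not how the builder actually uses its tags.

```cpp
#include <string_view>
#include <type_traits>

// Hypothetical tags, one per kernel family.
struct XdlTag
{
};
struct WmmaTag
{
};

// Hypothetical traits types that publish their tag.
struct MockXdlTraits
{
    using device_kernel_tag = XdlTag;
};
struct MockWmmaTraits
{
    using device_kernel_tag = WmmaTag;
};

// Selects behaviour from the tag at compile time.
template <typename InstTraits>
constexpr std::string_view warp_gemm_kind()
{
    using Tag = typename InstTraits::device_kernel_tag;
    if constexpr(std::is_same_v<Tag, XdlTag>)
        return "xdl";
    else if constexpr(std::is_same_v<Tag, WmmaTag>)
        return "wmma";
    else
        return "unknown";
}

static_assert(warp_gemm_kind<MockXdlTraits>() == "xdl");
static_assert(warp_gemm_kind<MockWmmaTraits>() == "wmma");
```

Keeping the tag as a distinct type, rather than a string or an enum value, means a new kernel family can be added without editing a central enumeration.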
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
@@ -158,7 +162,9 @@ struct InstanceTraits<
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag;
|
||||
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -175,13 +181,13 @@ struct InstanceTraits<
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kKPerBlock = KPerBlock;
|
||||
static constexpr ck::index_t kABK1 = ABK1;
|
||||
static constexpr ck::index_t kK1 = ABK1;
|
||||
static constexpr ck::index_t kMPerWmma = MPerWmma;
|
||||
static constexpr ck::index_t kNPerWmma = NPerWmma;
|
||||
static constexpr ck::index_t kMRepeat = MRepeat;
|
||||
@@ -195,27 +201,46 @@ struct InstanceTraits<
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector =
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
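Several members in this hunk convert a compile-time sequence type into a `std::array` through `detail::SequenceToArray<...>::value`, so that the values can be stored in ordinary structs. A minimal sketch of how such a conversion can be written is shown below. `Seq` and `SeqToArray` are hypothetical stand-ins; the real `ck::Sequence` and `detail::SequenceToArray` may differ in element type and details.

```cpp
#include <array>
#include <cstddef>

// Hypothetical compile-time integer sequence, standing in for ck::Sequence.
template <int... Is>
struct Seq
{
};

// Primary template is left undefined so that non-sequence types fail to compile.
template <typename T>
struct SeqToArray;

// Partial specialization unpacks the integers into a constexpr std::array.
template <int... Is>
struct SeqToArray<Seq<Is...>>
{
    static constexpr std::array<int, sizeof...(Is)> value{Is...};
};

// Example: thread cluster lengths given as a type, read back as values.
using ClusterLengths           = Seq<4, 64, 1>;
constexpr auto cluster_lengths = SeqToArray<ClusterLengths>::value;

static_assert(cluster_lengths.size() == 3);
static_assert(cluster_lengths[0] == 4);
static_assert(cluster_lengths[1] == 64);
```

Once the values live in a `std::array`, they can be copied into description structs, indexed, and printed without further template machinery.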
|
||||
|
||||
@@ -231,7 +256,7 @@ struct InstanceTraits<
|
||||
oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -251,30 +276,30 @@ struct InstanceTraits<
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kKPerBlock; // 18. KPerBlock
|
||||
oss << "," << kABK1; // 19. ABK1
|
||||
oss << "," << kMPerWmma; // 20. MPerWmma
|
||||
oss << "," << kNPerWmma; // 21. NPerWmma
|
||||
oss << "," << kMRepeat; // 22. MRepeat
|
||||
oss << "," << kNRepeat; // 23. NRepeat
|
||||
kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kKPerBlock; // 18. KPerBlock
|
||||
oss << "," << kK1; // 19. ABK1
|
||||
oss << "," << kMPerWmma; // 20. MPerWmma
|
||||
oss << "," << kNPerWmma; // 21. NPerWmma
|
||||
oss << "," << kMRepeat; // 22. MRepeat
|
||||
oss << "," << kNRepeat; // 23. NRepeat
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 27.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 34.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 38.
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 39.
|
||||
oss << ","
|
||||
|
||||
@@ -59,6 +59,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle;
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
/// @brief Tag type for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle device kernel
|
||||
struct DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout_,
|
||||
@@ -152,7 +156,10 @@ struct InstanceTraits<
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag;
|
||||
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -169,7 +176,7 @@ struct InstanceTraits<
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
@@ -188,22 +195,36 @@ struct InstanceTraits<
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
@@ -211,6 +232,9 @@ struct InstanceTraits<
|
||||
using ComputeTypeA = ComputeTypeA_;
|
||||
using ComputeTypeB = ComputeTypeB_;
|
||||
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
@@ -220,7 +244,7 @@ struct InstanceTraits<
|
||||
oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -240,30 +264,30 @@ struct InstanceTraits<
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 18. K0PerBlock
|
||||
oss << "," << kK1; // 19. K1
|
||||
oss << "," << kMPerXDL; // 20. MPerXDL
|
||||
oss << "," << kNPerXDL; // 21. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
|
||||
kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 18. K0PerBlock
|
||||
oss << "," << kK1; // 19. K1
|
||||
oss << "," << kMPerXDL; // 20. MPerXDL
|
||||
oss << "," << kNPerXDL; // 21. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 27.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_K1; // 29.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 34.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39.
|
||||
oss << ","
|
||||
|
||||
@@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3;
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 device kernel
|
||||
struct DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
@@ -161,7 +166,9 @@ struct InstanceTraits<
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag;
|
||||
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -176,7 +183,7 @@ struct InstanceTraits<
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
@@ -201,12 +208,21 @@ struct InstanceTraits<
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
|
||||
@@ -215,13 +231,26 @@ struct InstanceTraits<
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector =
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
|
||||
|
||||
@@ -237,7 +266,7 @@ struct InstanceTraits<
|
||||
oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -255,30 +284,30 @@ struct InstanceTraits<
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kKPerBlock; // 16. KPerBlock
|
||||
oss << "," << kABK1; // 17. ABK1
|
||||
oss << "," << kMPerWmma; // 18. MPerWmma
|
||||
oss << "," << kNPerWmma; // 19. NPerWmma
|
||||
oss << "," << kMRepeat; // 20. MRepeat
|
||||
oss << "," << kNRepeat; // 21. NRepeat
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kKPerBlock; // 16. KPerBlock
|
||||
oss << "," << kABK1; // 17. ABK1
|
||||
oss << "," << kMPerWmma; // 18. MPerWmma
|
||||
oss << "," << kNPerWmma; // 19. NPerWmma
|
||||
oss << "," << kMRepeat; // 20. MRepeat
|
||||
oss << "," << kNRepeat; // 21. NRepeat
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
@@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle device kernel
|
||||
struct DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
@@ -160,7 +165,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag;
|
||||
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -175,7 +182,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
@@ -199,26 +206,45 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
|
||||
|
||||
@@ -234,7 +260,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
|
||||
oss << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -252,30 +278,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kKPerBlock; // 16. KPerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kKPerBlock; // 16. KPerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
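Reviewer note: each `InstanceTraits` specialization in this change exposes an empty tag struct through `device_kernel_tag`. The diff does not show how the tag is consumed, so the following is only a minimal sketch of one plausible use: compile-time dispatch on the tag. All names ending in `Sketch`, and the function `matrix_instruction_name`, are hypothetical and are not part of the CK sources.

```cpp
#include <string_view>
#include <type_traits>

// Hypothetical tag types, mirroring the empty-struct pattern used in this diff.
struct XdlTagSketch
{
};
struct WmmaTagSketch
{
};

// Hypothetical traits classes that publish their tag as device_kernel_tag.
struct XdlTraitsSketch
{
    using device_kernel_tag = XdlTagSketch;
};
struct WmmaTraitsSketch
{
    using device_kernel_tag = WmmaTagSketch;
};

// Select a branch at compile time based on the tag carried by the traits class.
template <typename Traits>
constexpr std::string_view matrix_instruction_name()
{
    using Tag = typename Traits::device_kernel_tag;
    if constexpr(std::is_same_v<Tag, XdlTagSketch>)
        return "xdl";
    else if constexpr(std::is_same_v<Tag, WmmaTagSketch>)
        return "wmma";
    else
        return "unknown";
}

static_assert(matrix_instruction_name<XdlTraitsSketch>() == "xdl");
static_assert(matrix_instruction_name<WmmaTraitsSketch>() == "wmma");
```

An empty struct costs nothing at run time, and a distinct type per kernel family lets generic code branch without comparing `kTensorOpName` strings.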
@@ -59,6 +59,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle;
namespace ck_tile {
namespace reflect {

/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
{
};

template <ck::index_t NDimSpatial,
          typename InLayout_,
          typename WeiLayout_,
@@ -148,8 +153,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
false>> // Use false to match with the default value
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag;
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -164,15 +170,15 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
|
||||
static constexpr ck::index_t kK1 = K1;
|
||||
static constexpr ck::index_t kMPerWMMA = MPerWMMA;
|
||||
static constexpr ck::index_t kNPerWMMA = NPerWMMA;
|
||||
static constexpr ck::index_t kMPerWmma = MPerWMMA;
|
||||
static constexpr ck::index_t kNPerWmma = NPerWMMA;
|
||||
static constexpr ck::index_t kMRepeat = MRepeat;
|
||||
static constexpr ck::index_t kNRepeat = NRepeat;
|
||||
static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
|
||||
@@ -184,26 +190,43 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector =
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
static constexpr ck::LoopScheduler kLoopSched = LoopSched;
|
||||
static constexpr ck::PipelineVersion kPipelineVer = PipelineVer;
|
||||
|
||||
@@ -216,7 +239,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -234,30 +257,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerWMMA; // 18. MPerWMMA
|
||||
oss << "," << kNPerWMMA; // 19. NPerWMMA
|
||||
oss << "," << kMRepeat; // 20. MRepeat
|
||||
oss << "," << kNRepeat; // 21. NRepeat
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerWmma; // 18. MPerWMMA
|
||||
oss << "," << kNPerWmma; // 19. NPerWMMA
|
||||
oss << "," << kMRepeat; // 20. MRepeat
|
||||
oss << "," << kNRepeat; // 21. NRepeat
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
@@ -62,6 +62,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {

/// @brief Tag type for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 device kernel
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag
{
};

template <ck::index_t NDimSpatial,
          typename InLayout_,
          typename WeiLayout_,
@@ -156,8 +161,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
MaxTransposeTransferDstScalarPerVector>>
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3";
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag;
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -172,13 +178,13 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kKPerBlock = KPerBlock;
|
||||
static constexpr ck::index_t kABK1 = ABK1;
|
||||
static constexpr ck::index_t kK1 = ABK1;
|
||||
static constexpr ck::index_t kMPerWmma = MPerWmma;
|
||||
static constexpr ck::index_t kNPerWmma = NPerWmma;
|
||||
static constexpr ck::index_t kMRepeat = MRepeat;
|
||||
@@ -196,27 +202,46 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector =
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
|
||||
|
||||
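Reviewer note: the new `kAThreadClusterLengths`, `kBThreadClusterLengths` and related members rely on `detail::SequenceToArray<...>::value` to turn a compile-time sequence type into a `std::array`. Its definition is not part of this diff, so the following is only a sketch of how such a trait can be written. `Seq` and `SequenceToArraySketch` are hypothetical stand-ins, not the CK types.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for a compile-time integer sequence type.
template <int... Is>
struct Seq
{
};

// Primary template is left undefined; only sequence types are supported.
template <typename T>
struct SequenceToArraySketch;

// Partial specialization unpacks the integer pack into a constexpr std::array.
template <int... Is>
struct SequenceToArraySketch<Seq<Is...>>
{
    static constexpr std::array<int, sizeof...(Is)> value = {Is...};
};

// The array is usable in constant expressions, e.g. for a K0 x M x K1 cluster.
static_assert(SequenceToArraySketch<Seq<4, 64, 1>>::value.size() == 3);
static_assert(SequenceToArraySketch<Seq<4, 64, 1>>::value[1] == 64);
```

Exposing values rather than types means downstream code can index and print the cluster lengths without knowing the sequence type.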
@@ -232,7 +257,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -250,30 +275,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kKPerBlock; // 16. KPerBlock
|
||||
oss << "," << kABK1; // 17. ABK1
|
||||
oss << "," << kMPerWmma; // 18. MPerWmma
|
||||
oss << "," << kNPerWmma; // 19. NPerWmma
|
||||
oss << "," << kMRepeat; // 20. MRepeat
|
||||
oss << "," << kNRepeat; // 21. NRepeat
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kKPerBlock; // 16. KPerBlock
|
||||
oss << "," << kK1; // 17. ABK1
|
||||
oss << "," << kMPerWmma; // 18. MPerWmma
|
||||
oss << "," << kNPerWmma; // 19. NPerWmma
|
||||
oss << "," << kMRepeat; // 20. MRepeat
|
||||
oss << "," << kNRepeat; // 21. NRepeat
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {

/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
{
};

template <ck::index_t NDimSpatial,
          typename InLayout_,
          typename WeiLayout_,
@@ -152,7 +157,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -167,43 +173,63 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
|
||||
static constexpr ck::index_t kK1 = K1;
|
||||
static constexpr ck::index_t kMPerXDL = MPerXDL;
|
||||
static constexpr ck::index_t kNPerXDL = NPerXDL;
|
||||
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
|
||||
static constexpr ck::index_t kK1 = K1;
|
||||
static constexpr ck::index_t kMPerXDL = MPerXDL;
|
||||
static constexpr ck::index_t kNPerXDL = NPerXDL;
|
||||
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
|
||||
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
|
||||
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
@@ -224,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
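Reviewer note: the instance-string code in these hunks streams each template parameter into an `std::ostringstream` in declaration order, separated by commas, and prints booleans as `true`/`false`. A reduced sketch of that pattern follows. `InstanceStringSketch`, its three members, and the kernel name `HypotheticalKernel` are invented for illustration; the real specializations emit many more parameters.

```cpp
#include <sstream>
#include <string>

// Hypothetical traits struct with a handful of reflected parameters.
struct InstanceStringSketch
{
    static constexpr int kSpatialDim       = 2;
    static constexpr int kBlockSize        = 256;
    static constexpr bool kABlockLdsExtraM = true;

    // Build "Name<p1,p2,...>" with parameters in template-declaration order.
    static std::string instance_string()
    {
        std::ostringstream oss;
        oss << "HypotheticalKernel";
        oss << "<" << kSpatialDim;                           // 1. spatial dimension
        oss << "," << kBlockSize;                            // 2. block size
        oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 3. bool as text
        oss << ">";
        return oss.str();
    }
};
```

Because the string is built from the renamed members (`kSpatialDim`, `kABlockLdsExtraM`, and so on), renaming a member does not change the printed output as long as the order and values stay the same.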
@@ -242,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
namespace ck_tile {
namespace reflect {

/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
{
};

template <ck::index_t NDimSpatial,
          typename InLayout_,
          typename WeiLayout_,
@@ -150,9 +155,12 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
ComputeTypeA_,
|
||||
ComputeTypeB_>>
|
||||
{
|
||||
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag;
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -167,7 +175,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
@@ -182,28 +190,48 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
@@ -222,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -240,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
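Note on the hunk above: the instance string is assembled by streaming each template parameter, in declaration order, into an `std::ostringstream`, with booleans written out as `true`/`false`. A minimal sketch of that pattern follows. `MiniTraits`, `MiniKernel` and the reduced three-parameter list are hypothetical stand-ins for illustration only; the real `InstanceTraits` specialization carries many more parameters.

```cpp
#include <sstream>
#include <string>

// Hypothetical, heavily reduced traits: three parameters instead of the full list.
struct MiniTraits
{
    static constexpr int kSpatialDim       = 2;
    static constexpr int kBlockSize        = 256;
    static constexpr bool kABlockLdsExtraM = true;

    // Build "Name<p1,p2,...>" by appending parameters in declaration order.
    static std::string instance_string()
    {
        std::ostringstream oss;
        oss << "MiniKernel";                                 // kernel name (assumed)
        oss << "<" << kSpatialDim;                           // first parameter opens the list
        oss << "," << kBlockSize;                            // later parameters are comma-separated
        oss << "," << (kABlockLdsExtraM ? "true" : "false"); // bools printed as words
        oss << ">";
        return oss.str();
    }
};
```

Keeping one `oss <<` statement per parameter, as the diff does, makes it easy to compare the output order against the template parameter list when a member is renamed.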
@@ -79,6 +79,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device kernel
|
||||
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
@@ -176,6 +181,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
LoopSched,
|
||||
PipelineVer>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag;
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
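Note on the hunks above: each device kernel gains an empty tag struct, which its `InstanceTraits` specialization exposes as `device_kernel_tag`. One plausible use of such a tag is compile-time dispatch on the kernel variant. The sketch below only illustrates that idea; `XdlV3Tag`, `WmmaFwdTag`, `FakeXdlTraits`, `FakeWmmaTraits` and `kernel_family` are hypothetical names, not identifiers from the repository.

```cpp
#include <string>
#include <type_traits>

// Hypothetical tag types standing in for the *_Tag structs added in the diff.
struct XdlV3Tag {};
struct WmmaFwdTag {};

// Hypothetical traits classes that expose a device_kernel_tag alias.
struct FakeXdlTraits  { using device_kernel_tag = XdlV3Tag; };
struct FakeWmmaTraits { using device_kernel_tag = WmmaFwdTag; };

// Pick a branch at compile time from the tag carried by the traits class.
template <typename Traits>
std::string kernel_family()
{
    using Tag = typename Traits::device_kernel_tag;
    if constexpr (std::is_same_v<Tag, XdlV3Tag>)
        return "xdl";
    else if constexpr (std::is_same_v<Tag, WmmaFwdTag>)
        return "wmma";
    else
        return "unknown";
}
```

An empty tag type costs nothing at run time and lets shared code distinguish kernel variants without comparing name strings.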
File diff suppressed because it is too large
@@ -259,9 +259,118 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Instance = ckb::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
|
||||
EXPECT_THAT(
|
||||
ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
"│ ├─ Weight Layout: GKYXC\n"
|
||||
"│ ├─ Output Layout: GNHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 2\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// max_transpose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
|
||||
{
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // AK1
|
||||
32, // MPerWMMA
|
||||
32, // NPerWMMA
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
4, // NumGroupsToMerge
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
EXPECT_THAT(ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"2D Backward Weight Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
@@ -272,37 +381,146 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Data tile size: 128×128×16\n"
|
||||
" ├─ Struct does not contain optional gemm_padding argument\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Pipeline version: V1\n"
|
||||
" ├─ Pipeline scheduler: DEFAULT\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" └─ Vector access (GMEM write) instruction size: 2"));
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
|
||||
" ├─ Max Transpose dst scalar per vector: 1\n"
|
||||
" └─ Num groups to merge: 4"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// max_transpose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest)
|
||||
{
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
3, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNDHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKZYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNDHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerWmma
|
||||
32, // NPerWmma
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
1, // NumGemmKPrefetchStage
|
||||
ck::LoopScheduler::Default, // BlkGemmPipeSched
|
||||
ck::PipelineVersion::v1, // BlkGemmPipelineVer
|
||||
false>; // BComputeDataType
|
||||
|
||||
EXPECT_THAT(
|
||||
ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"3D Backward Weight Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNDHWC\n"
|
||||
"│ ├─ Weight Layout: GKZYXC\n"
|
||||
"│ ├─ Output Layout: GNDHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 128×128×16\n"
|
||||
" ├─ Struct does not contain optional gemm_padding argument\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V1\n"
|
||||
" ├─ Pipeline scheduler: DEFAULT\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Num gemm k prefetch stage: 1\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString)
|
||||
|
||||
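Note on the tests above: the expected strings show two forms for each optional parameter, either `<label>: <value>` when the kernel defines it or `Struct does not contain optional <name> parameter` when it does not. A minimal sketch of how such a line could be produced from an `std::optional` follows. The helper `describe_optional` and its signature are assumptions made for illustration and are not the function used in `conv_description.hpp`.

```cpp
#include <optional>
#include <sstream>
#include <string>

// Hypothetical helper: formats one optional parameter for the description tree.
// label      - human-readable label used when the value is present
// field_name - snake_case parameter name used in the "missing" message
inline std::string describe_optional(const std::string& label,
                                     const std::string& field_name,
                                     const std::optional<int>& value)
{
    std::ostringstream oss;
    if(value.has_value())
        oss << label << ": " << *value;
    else
        oss << "Struct does not contain optional " << field_name << " parameter";
    return oss.str();
}
```

With this shape, a single traits struct can describe every kernel variant, and each test only needs to assert which of the two forms appears.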