mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Merge commit '1fbb47ad304566a90a374cef4731f1a257e5e179' into develop
This commit is contained in:
@@ -58,7 +58,7 @@ struct TransformConvFwdToGemm
|
||||
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
|
||||
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
|
||||
c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB threshold
|
||||
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
|
||||
@@ -78,23 +78,21 @@ struct GroupedConvFwdKernelArgs
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
// Create and STORE transformer (for split-image support)
|
||||
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
@@ -106,13 +104,16 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
// Initialize Split-N support fields for 1D convolution (NWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_per_split = transformer_.GetN();
|
||||
original_n = transformer_.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NWGC layout
|
||||
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];
|
||||
// Calculate batch strides using the original argument dimensions.
|
||||
// These are the original dimensions passed to the constructor, not modified by the invoker
|
||||
// yet. (The invoker modifies args after calling MakeKernelArgs.) VERIFIED: G_ MUST be
|
||||
// included - NWGC layout has all groups within each batch
|
||||
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0];
|
||||
@@ -169,23 +170,21 @@ struct GroupedConvFwdKernelArgs
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
// Create and STORE transformer (for split-image support)
|
||||
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
@@ -197,15 +196,16 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_per_split = transformer_.GetN();
|
||||
original_n = transformer_.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NHWGC layout
|
||||
// VERIFIED: G_ MUST be included - NHWGC layout has all groups within each batch
|
||||
input_batch_stride =
|
||||
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
|
||||
args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
|
||||
output_batch_stride =
|
||||
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
@@ -270,23 +270,21 @@ struct GroupedConvFwdKernelArgs
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
// Create and STORE transformer (for split-image support)
|
||||
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
@@ -298,14 +296,15 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_per_split = transformer_.GetN();
|
||||
original_n = transformer_.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NDHWGC layout
|
||||
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
|
||||
// VERIFIED: G_ MUST be included - NDHWGC layout has all groups within each batch
|
||||
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
|
||||
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
|
||||
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
|
||||
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
|
||||
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
@@ -359,6 +358,42 @@ struct GroupedConvFwdKernelArgs
|
||||
index_t original_n = 1; // Original batch size before splitting
|
||||
index_t input_batch_stride = 0; // Stride to next batch in input tensor
|
||||
index_t output_batch_stride = 0; // Stride to next batch in output tensor
|
||||
|
||||
// Split-image support - spatial offsets (applied per-batch in operator())
|
||||
long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split)
|
||||
long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split)
|
||||
|
||||
// Split-image support - transformer instance
|
||||
ConvToGemmFwdTransformer transformer_;
|
||||
|
||||
// Forward declare descriptor types (will be defined after using declarations)
|
||||
using ConvToGemmFwdTransformer_t = ConvToGemmFwdTransformer;
|
||||
using AGridDescMK_t = AGridDescMK;
|
||||
using CGridDescMN_t = CGridDescMN;
|
||||
|
||||
// Split-image support: Common data for all pieces
|
||||
struct SplitImageInfo
|
||||
{
|
||||
// Common dimensions (same for all pieces)
|
||||
index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
|
||||
index_t total_spatial = 1; // Pre-calculated: total_d * total_h * total_w
|
||||
index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors
|
||||
|
||||
// Minimal per-piece data (only unique values)
|
||||
struct PieceInfo
|
||||
{
|
||||
index_t block_start; // Starting block index for this piece
|
||||
index_t block_end; // Ending block index (exclusive)
|
||||
index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
|
||||
index_t d_size, h_size, w_size; // Piece size in OUTPUT space
|
||||
};
|
||||
|
||||
static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
|
||||
std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
|
||||
};
|
||||
|
||||
index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
|
||||
SplitImageInfo split_image; // Nested structure with common + per-piece data
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
@@ -399,13 +434,15 @@ struct GroupedConvFwdKernelArgs
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType_,
|
||||
template <bool EnableSplitImage_,
|
||||
typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -435,7 +472,6 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = false;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -449,6 +485,77 @@ struct GroupedConvolutionForwardKernel
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
// Helper struct for spatial coordinates
|
||||
struct SpatialCoords
|
||||
{
|
||||
index_t d, h, w;
|
||||
};
|
||||
|
||||
// Helper: Convert flat spatial index to (d,h,w) coordinates
|
||||
CK_TILE_DEVICE static SpatialCoords
|
||||
UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
return SpatialCoords{0, 0, flat};
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
return SpatialCoords{0, flat / w_size, flat % w_size};
|
||||
}
|
||||
else // NDimSpatial == 3
|
||||
{
|
||||
const index_t hw = h_size * w_size;
|
||||
const index_t d = flat / hw;
|
||||
const index_t remainder = flat % hw;
|
||||
return SpatialCoords{d, remainder / w_size, remainder % w_size};
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: Convert (d,h,w) to flat spatial index
|
||||
CK_TILE_DEVICE static index_t
|
||||
FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
return w;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
return h * total_w + w;
|
||||
}
|
||||
else // NDimSpatial == 3
|
||||
{
|
||||
return (d * total_h + h) * total_w + w;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: Find which piece owns a block using binary search
|
||||
template <typename SplitImageInfo>
|
||||
CK_TILE_DEVICE static index_t
|
||||
FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
|
||||
{
|
||||
index_t left = 0;
|
||||
index_t right = num_pieces - 1;
|
||||
index_t piece_id = (left + right) / 2;
|
||||
|
||||
while(!(block_id >= split_info.pieces[piece_id].block_start &&
|
||||
block_id < split_info.pieces[piece_id].block_end) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_id < split_info.pieces[piece_id].block_start)
|
||||
{
|
||||
right = piece_id - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = piece_id + 1;
|
||||
}
|
||||
piece_id = (left + right) / 2;
|
||||
}
|
||||
return piece_id;
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -475,7 +582,8 @@ struct GroupedConvolutionForwardKernel
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -499,17 +607,6 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
}
|
||||
|
||||
// Check Split-K and Split-N conflict (both use blockIdx.z)
|
||||
if(kargs.k_batch > 1 && kargs.n_splits > 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
@@ -618,27 +715,32 @@ struct GroupedConvolutionForwardKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
typename ADescType,
|
||||
typename BDescType,
|
||||
typename CDescType>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
@@ -651,7 +753,7 @@ struct GroupedConvolutionForwardKernel
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -743,31 +845,39 @@ struct GroupedConvolutionForwardKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param ds_ptr input D tensors pointer array
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param a_desc Input tensor A descriptor
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <typename ADescType, typename BDescType, typename CDescType>
|
||||
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -780,9 +890,8 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{kargs.elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -792,32 +901,40 @@ struct GroupedConvolutionForwardKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param ds_ptr input D tensors pointer array
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param a_desc Input tensor A descriptor
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <typename ADescType, typename BDescType, typename CDescType>
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -837,12 +954,8 @@ struct GroupedConvolutionForwardKernel
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
{
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
@@ -860,14 +973,89 @@ struct GroupedConvolutionForwardKernel
|
||||
static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.output_batch_stride);
|
||||
|
||||
// Adjust pointers: combine group offset and batch offset
|
||||
const InDataType* a_ptr =
|
||||
// Calculate base pointers with group and batch offsets
|
||||
const InDataType* base_a_ptr =
|
||||
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
|
||||
group_offset_b; // No batch offset for weights!
|
||||
OutDataType* c_ptr =
|
||||
OutDataType* base_c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// =====================================================================
|
||||
// Split-image: Map local block to global tile index (if enabled)
|
||||
// =====================================================================
|
||||
const InDataType* a_ptr;
|
||||
OutDataType* c_ptr;
|
||||
index_t i_m = 0;
|
||||
index_t i_n = 0;
|
||||
|
||||
// Pre-calculate block_id (used in both split-image and non-split paths)
|
||||
const index_t block_id = static_cast<index_t>(blockIdX);
|
||||
|
||||
if constexpr(EnableSplitImage)
|
||||
{
|
||||
// Add spatial offsets for split-image (constexpr optimization)
|
||||
a_ptr = base_a_ptr + kargs.spatial_offset_in;
|
||||
c_ptr = base_c_ptr + kargs.spatial_offset_out;
|
||||
|
||||
// Find which piece owns this block using binary search
|
||||
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
|
||||
const index_t piece_id =
|
||||
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
|
||||
const auto& piece = kargs.split_image.pieces[piece_id];
|
||||
const auto& split_info = kargs.split_image;
|
||||
|
||||
// Calculate local block ID and tile indices
|
||||
const index_t local_block_id = block_id - piece.block_start;
|
||||
const index_t local_gemm_m =
|
||||
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
|
||||
const auto [local_tile_m, local_tile_n] =
|
||||
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
|
||||
|
||||
// Extract batch and spatial coordinates from local tile
|
||||
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
|
||||
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
|
||||
const index_t local_n = local_m_start / spatial_per_batch;
|
||||
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
|
||||
|
||||
// Convert to local spatial coordinates
|
||||
const auto local_coords =
|
||||
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
|
||||
|
||||
// Convert to global spatial coordinates
|
||||
const index_t global_n = local_n;
|
||||
const index_t global_d = piece.d_start + local_coords.d;
|
||||
const index_t global_h = piece.h_start + local_coords.h;
|
||||
const index_t global_w = piece.w_start + local_coords.w;
|
||||
|
||||
// Convert to global M index
|
||||
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
|
||||
const index_t global_spatial_flat = FlattenSpatial(
|
||||
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
|
||||
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
|
||||
|
||||
// Set tile indices for GEMM operation
|
||||
i_m = amd_wave_read_first_lane(global_m);
|
||||
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
|
||||
}
|
||||
else
|
||||
{
|
||||
// No spatial offsets needed for regular path
|
||||
a_ptr = base_a_ptr;
|
||||
c_ptr = base_c_ptr;
|
||||
|
||||
// No split-image: use standard tile partitioning
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
|
||||
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
}
|
||||
|
||||
// Use global descriptors for all cases
|
||||
const auto& a_desc = kargs.a_grid_desc_m_k;
|
||||
const auto& b_desc = kargs.b_grid_desc_n_k;
|
||||
const auto& c_desc = kargs.c_grid_desc_m_n;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
@@ -878,8 +1066,18 @@ struct GroupedConvolutionForwardKernel
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -888,7 +1086,17 @@ struct GroupedConvolutionForwardKernel
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,4 +110,86 @@ struct GroupedConvTraits
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
/// @brief Helper struct for split-image piece information
|
||||
///
|
||||
/// @par Overview
|
||||
/// Stores metadata for a single spatial piece in split-image convolution.
|
||||
/// Used to track block ranges and spatial coordinates for each piece.
|
||||
struct SplitImagePieceInfo
|
||||
{
|
||||
ck_tile::index_t block_start, block_end; ///< GPU block range for this piece
|
||||
ck_tile::index_t d_start, h_start, w_start; ///< Spatial start coordinates (output space)
|
||||
ck_tile::index_t d_size, h_size, w_size; ///< Spatial dimensions of this piece
|
||||
};
|
||||
|
||||
/// @brief Calculate piece information for split-image convolution
|
||||
///
|
||||
/// @par Overview
|
||||
/// Computes spatial coordinates, dimensions, and GPU block range for a single
|
||||
/// piece in split-image convolution. Handles edge pieces that may have different
|
||||
/// sizes due to non-uniform division.
|
||||
///
|
||||
/// @tparam TilePartitioner Type providing MPerBlock and NPerBlock constants
|
||||
///
|
||||
/// @param piece_idx Index of the piece to calculate (0-based)
|
||||
/// @param num_d_pieces Number of pieces in D dimension
|
||||
/// @param num_h_pieces Number of pieces in H dimension
|
||||
/// @param num_w_pieces Number of pieces in W dimension
|
||||
/// @param base_piece_d Base size of each D piece (may differ for last piece)
|
||||
/// @param base_piece_h Base size of each H piece (may differ for last piece)
|
||||
/// @param base_piece_w Base size of each W piece (may differ for last piece)
|
||||
/// @param total_d Total D dimension size (output space)
|
||||
/// @param total_h Total H dimension size (output space)
|
||||
/// @param total_w Total W dimension size (output space)
|
||||
/// @param N Batch size
|
||||
/// @param K Output channels
|
||||
/// @param total_blocks Accumulated block count from previous pieces
|
||||
///
|
||||
/// @return SplitImagePieceInfo containing all metadata for this piece
|
||||
template <typename TilePartitioner>
|
||||
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx,
|
||||
ck_tile::index_t num_d_pieces,
|
||||
ck_tile::index_t num_h_pieces,
|
||||
ck_tile::index_t num_w_pieces,
|
||||
ck_tile::index_t base_piece_d,
|
||||
ck_tile::index_t base_piece_h,
|
||||
ck_tile::index_t base_piece_w,
|
||||
ck_tile::index_t total_d,
|
||||
ck_tile::index_t total_h,
|
||||
ck_tile::index_t total_w,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t total_blocks)
|
||||
{
|
||||
// Unflatten piece index into 3D coordinates (W-major, then H, then D)
|
||||
const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
|
||||
const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
|
||||
const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
|
||||
|
||||
// Calculate spatial start positions
|
||||
const ck_tile::index_t w_start = w_idx * base_piece_w;
|
||||
const ck_tile::index_t h_start = h_idx * base_piece_h;
|
||||
const ck_tile::index_t d_start = d_idx * base_piece_d;
|
||||
|
||||
// Calculate piece sizes (last piece may be larger to cover remainder)
|
||||
const ck_tile::index_t w_size =
|
||||
(w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
|
||||
const ck_tile::index_t h_size =
|
||||
(h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
|
||||
const ck_tile::index_t d_size =
|
||||
(d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
|
||||
|
||||
// Calculate GEMM dimensions for this piece
|
||||
const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
|
||||
const ck_tile::index_t piece_gemm_n = K;
|
||||
|
||||
// Calculate GPU grid size for this piece
|
||||
const ck_tile::index_t piece_grid =
|
||||
((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
|
||||
((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
|
||||
|
||||
return {
|
||||
total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,9 +5,15 @@
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// Split-Image Information Structure
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// This structure holds all information needed to perform split-image
|
||||
// NOTE: SplitImageInfo struct deleted - was only used by deleted recursive split code
|
||||
// Current split-image implementation is in grouped_convolution_forward_invoker.hpp
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvSpecialization,
|
||||
index_t VectorSizeA,
|
||||
@@ -28,6 +34,9 @@ struct TransformConvFwdToGemm
|
||||
static constexpr auto I4 = number<4>{};
|
||||
static constexpr auto I5 = number<5>{};
|
||||
|
||||
// Unified memory limit constant for both Split-N and Split-Image
|
||||
static constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
|
||||
const ConvDimsType& strides,
|
||||
@@ -47,6 +56,7 @@ struct TransformConvFwdToGemm
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths)
|
||||
{
|
||||
|
||||
// Calculate strides internally assuming contiguous memory layout
|
||||
ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
|
||||
const index_t num_dims = a_g_n_c_wis_lengths.size();
|
||||
@@ -71,7 +81,6 @@ struct TransformConvFwdToGemm
|
||||
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
|
||||
const long_index_t element_space_size = ck_tile::max(
|
||||
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
|
||||
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
@@ -111,6 +120,145 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Structure to hold split-image decision and factors
|
||||
struct SplitImageInfo
|
||||
{
|
||||
bool should_split;
|
||||
index_t num_d_pieces;
|
||||
index_t num_h_pieces;
|
||||
index_t num_w_pieces;
|
||||
};
|
||||
|
||||
// Calculate split-image factors AFTER considering split-N
|
||||
// Returns: should_split flag and optimal split factors for D, H, W dimensions
|
||||
// Strategy: Hierarchical splitting with priority order D → H → W
|
||||
// Dynamically increases split factors until memory fits below threshold
|
||||
//
|
||||
// NOTE: Layout validation should be done at the invoker level before calling this function
|
||||
// Split-image only works with specific layouts:
|
||||
// 1D: NWGC (input), GKXC (weight), NWGK (output)
|
||||
// 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
|
||||
// 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
|
||||
CK_TILE_HOST static SplitImageInfo GetSplitImageInfo(
|
||||
index_t G, index_t N, index_t C, index_t K, index_t D_out, index_t H_out, index_t W_out)
|
||||
{
|
||||
SplitImageInfo info{false, 1, 1, 1};
|
||||
|
||||
// Estimate memory (simplified calculation)
|
||||
// Use max of input and output tensor sizes
|
||||
// Cast to long_index_t to prevent overflow during multiplication
|
||||
const long_index_t input_elements =
|
||||
static_cast<long_index_t>(N) * D_out * H_out * W_out * C * G;
|
||||
const long_index_t output_elements =
|
||||
static_cast<long_index_t>(N) * D_out * H_out * W_out * K * G;
|
||||
const long_index_t input_bytes = input_elements * sizeof(ADataType);
|
||||
const long_index_t output_bytes = output_elements * sizeof(CDataType);
|
||||
const long_index_t max_tensor_bytes =
|
||||
(input_bytes > output_bytes) ? input_bytes : output_bytes;
|
||||
|
||||
// Calculate effective N after split-N (simplified - assume worst case N=1)
|
||||
index_t effective_N = 1;
|
||||
if(max_tensor_bytes > TwoGB && N > 1)
|
||||
{
|
||||
// Split-N will reduce to approximately N=1 per launch
|
||||
effective_N = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
effective_N = N;
|
||||
}
|
||||
|
||||
// Check if split-image is needed
|
||||
auto calc_memory = [&](index_t d_split, index_t h_split, index_t w_split) -> long_index_t {
|
||||
index_t d_piece = D_out / d_split;
|
||||
index_t h_piece = H_out / h_split;
|
||||
index_t w_piece = W_out / w_split;
|
||||
// Cast to long_index_t to prevent overflow
|
||||
return static_cast<long_index_t>(effective_N) * d_piece * h_piece * w_piece * K * G *
|
||||
sizeof(CDataType);
|
||||
};
|
||||
|
||||
// Calculate memory after split-N with no spatial split
|
||||
const long_index_t memory_after_split_n = calc_memory(1, 1, 1);
|
||||
|
||||
// Check if split-image is needed
|
||||
if(memory_after_split_n <= TwoGB)
|
||||
{
|
||||
info.should_split = false;
|
||||
return info;
|
||||
}
|
||||
|
||||
// Split-image is needed - use hierarchical priority: D → H → W
|
||||
info.should_split = true;
|
||||
|
||||
// Hierarchical splitting strategy:
|
||||
// 1D: Split W until below threshold
|
||||
// 2D: Split H first, if still too large then split W
|
||||
// 3D: Split D first, then H, then W
|
||||
|
||||
// IMPORTANT: Maximum 64 pieces total (hardcoded array limit in invoker)
|
||||
constexpr index_t MAX_TOTAL_PIECES = 64;
|
||||
|
||||
// Start with no split
|
||||
info.num_d_pieces = 1;
|
||||
info.num_h_pieces = 1;
|
||||
info.num_w_pieces = 1;
|
||||
|
||||
// Try splitting D first (for 3D)
|
||||
if(D_out > 1)
|
||||
{
|
||||
index_t max_d_split = (D_out < MAX_TOTAL_PIECES) ? D_out : MAX_TOTAL_PIECES;
|
||||
for(index_t d_split = 2; d_split <= max_d_split; d_split++)
|
||||
{
|
||||
info.num_d_pieces = d_split;
|
||||
if(calc_memory(d_split, 1, 1) <= TwoGB)
|
||||
{
|
||||
return info; // D split alone is sufficient
|
||||
}
|
||||
}
|
||||
// D split maxed out, try H next
|
||||
}
|
||||
|
||||
// Try splitting H (for 2D/3D)
|
||||
if(H_out > 1)
|
||||
{
|
||||
index_t max_h_split = MAX_TOTAL_PIECES / info.num_d_pieces;
|
||||
max_h_split = (H_out < max_h_split) ? H_out : max_h_split;
|
||||
|
||||
for(index_t h_split = 2; h_split <= max_h_split; h_split++)
|
||||
{
|
||||
info.num_h_pieces = h_split;
|
||||
if(calc_memory(info.num_d_pieces, h_split, 1) <= TwoGB)
|
||||
{
|
||||
return info; // D+H split is sufficient
|
||||
}
|
||||
}
|
||||
// H split maxed out, try W next
|
||||
}
|
||||
|
||||
// Try splitting W (for 1D/2D/3D)
|
||||
index_t max_w_split = MAX_TOTAL_PIECES / (info.num_d_pieces * info.num_h_pieces);
|
||||
max_w_split = (W_out < max_w_split) ? W_out : max_w_split;
|
||||
|
||||
for(index_t w_split = 2; w_split <= max_w_split; w_split++)
|
||||
{
|
||||
info.num_w_pieces = w_split;
|
||||
if(calc_memory(info.num_d_pieces, info.num_h_pieces, w_split) <= TwoGB)
|
||||
{
|
||||
return info; // D+H+W split is sufficient
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, even maximum split doesn't fit
|
||||
// Use maximum allowed split as best effort (capped at 64 total pieces)
|
||||
info.num_d_pieces = (D_out < 4) ? D_out : 4; // Cap at 4
|
||||
info.num_h_pieces = (H_out < 4) ? H_out : 4; // Cap at 4
|
||||
info.num_w_pieces = (W_out < 4) ? W_out : 4; // Cap at 4 (max 4×4×4=64)
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
public:
|
||||
// Public getter methods for Split-N support
|
||||
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
|
||||
@@ -192,14 +340,14 @@ struct TransformConvFwdToGemm
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
// Store original N and initialize N_
|
||||
original_N_ = N_ = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -244,18 +392,13 @@ struct TransformConvFwdToGemm
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
// Store original N
|
||||
original_N_ = c_g_n_k_wos_lengths[I1];
|
||||
// Store original N and initialize N_
|
||||
original_N_ = N_ = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
original_N_ = N_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -300,136 +443,26 @@ struct TransformConvFwdToGemm
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
// Store original N before potential splitting
|
||||
original_N_ = c_g_n_k_wos_lengths[I1];
|
||||
// Store original N and initialize N_
|
||||
original_N_ = N_ = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = original_N_;
|
||||
}
|
||||
}
|
||||
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
__host__ bool AreDescriptorsSmallerThan2GB() const
|
||||
// Check if descriptors fit within memory threshold
|
||||
// NOTE: Not currently used - split-image uses different approach in invoker
|
||||
CK_TILE_HOST bool AreDescriptorsSmallerThan2GB() const
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
const long_index_t input_size = static_cast<long_index_t>(N_) * Di_ * Hi_ * Wi_ * C_;
|
||||
const long_index_t output_size = static_cast<long_index_t>(N_) * Do_ * Ho_ * Wo_ * K_;
|
||||
|
||||
const long_index_t in_desc_space_size =
|
||||
I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
|
||||
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
|
||||
const long_index_t out_desc_space_size =
|
||||
I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
|
||||
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
|
||||
|
||||
bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
|
||||
bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
|
||||
|
||||
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
|
||||
const long_index_t threshold = TwoGB / sizeof(ADataType);
|
||||
return (input_size < threshold) && (output_size < threshold);
|
||||
}
|
||||
|
||||
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
|
||||
CDataType* c_grid_ptr_base) const
|
||||
{
|
||||
// Create copies
|
||||
auto conv_to_gemm_transformer_left = *this;
|
||||
auto conv_to_gemm_transformer_right = *this;
|
||||
IndexType a_right_offset = 0;
|
||||
IndexType c_right_offset = 0;
|
||||
// Calculate real filter size
|
||||
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
|
||||
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
|
||||
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
|
||||
// Calculate start position in input for right tensor
|
||||
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
|
||||
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
|
||||
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
|
||||
// Calculate last position in input for left tensor
|
||||
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
|
||||
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
|
||||
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
|
||||
// Allow to split if whole left padding will be in left tensor and right padding in right
|
||||
// tensor
|
||||
const bool is_possible_to_split_d = Do_ != 1 &&
|
||||
di_right_transformer_start_idx > InLeftPadD_ &&
|
||||
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
|
||||
const bool is_possible_to_split_h = Ho_ != 1 &&
|
||||
hi_right_transformer_start_idx > InLeftPadH_ &&
|
||||
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
|
||||
const bool is_possible_to_split_w = Wo_ != 1 &&
|
||||
wi_right_transformer_start_idx > InLeftPadW_ &&
|
||||
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
|
||||
|
||||
if(is_possible_to_split_d)
|
||||
{
|
||||
// Apply new sizes
|
||||
// Split output on half
|
||||
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
|
||||
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
|
||||
// Assign left padding to left convolution
|
||||
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
|
||||
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
|
||||
// Assign right padding to right convolution
|
||||
conv_to_gemm_transformer_left.InRightPadD_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
|
||||
// Calculate new input size
|
||||
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
|
||||
conv_to_gemm_transformer_right.Di_ =
|
||||
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
|
||||
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
|
||||
;
|
||||
// Calcualte offsets
|
||||
a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
|
||||
c_right_offset = (Do_ / 2) * DoStride_;
|
||||
}
|
||||
else if(is_possible_to_split_h)
|
||||
{
|
||||
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
|
||||
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
|
||||
|
||||
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
|
||||
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
|
||||
|
||||
conv_to_gemm_transformer_left.InRightPadH_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
|
||||
|
||||
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
|
||||
conv_to_gemm_transformer_right.Hi_ =
|
||||
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
|
||||
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
|
||||
a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
|
||||
c_right_offset = (Ho_ / 2) * HoStride_;
|
||||
}
|
||||
else if(is_possible_to_split_w)
|
||||
{
|
||||
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
|
||||
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
|
||||
|
||||
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
|
||||
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
|
||||
|
||||
conv_to_gemm_transformer_left.InRightPadW_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
|
||||
|
||||
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
|
||||
conv_to_gemm_transformer_right.Wi_ =
|
||||
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
|
||||
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
|
||||
|
||||
a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
|
||||
c_right_offset = (Wo_ / 2) * WoStride_;
|
||||
}
|
||||
// Return left transform, right transformer, right offset to Input and right offset to
|
||||
// Output
|
||||
return ck_tile::make_tuple(conv_to_gemm_transformer_left,
|
||||
conv_to_gemm_transformer_right,
|
||||
a_grid_ptr_base + a_right_offset,
|
||||
c_grid_ptr_base + c_right_offset);
|
||||
}
|
||||
#endif
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
// properties
|
||||
template <typename ALayout,
|
||||
@@ -1510,6 +1543,18 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// Split-Image Calculation (AFTER Split-N)
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// This method calculates split-image information using N_ (after Split-N).
|
||||
// This ensures correct offset calculations when both Split-N and Split-Image
|
||||
// are active simultaneously.
|
||||
|
||||
// NOTE: Deleted CalculateSplitImage() and LaunchWithRecursiveSplit() - dead code
|
||||
// Current split-image implementation is in grouped_convolution_forward_invoker.hpp
|
||||
|
||||
public:
|
||||
private:
|
||||
IndexType G_, N_, original_N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
|
||||
Reference in New Issue
Block a user