Merge commit '1fbb47ad304566a90a374cef4731f1a257e5e179' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-01 13:15:56 +00:00
parent d560ad2092
commit e065ebfb86
8 changed files with 1124 additions and 306 deletions

View File

@@ -78,23 +78,21 @@ struct GroupedConvFwdKernelArgs
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Create and STORE transformer (for split-image support)
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
a_grid_desc_m_k =
conv_to_gemm_transformer
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
b_grid_desc_n_k =
conv_to_gemm_transformer
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
c_grid_desc_m_n =
conv_to_gemm_transformer
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
@@ -106,13 +104,16 @@ struct GroupedConvFwdKernelArgs
// Initialize Split-N support fields for 1D convolution (NWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_per_split = transformer_.GetN();
original_n = transformer_.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];
// Calculate batch strides using the original argument dimensions.
// These are the original dimensions passed to the constructor, not modified by the invoker
// yet. (The invoker modifies args after calling MakeKernelArgs.) VERIFIED: G_ MUST be
// included - NWGC layout has all groups within each batch
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0];
@@ -169,23 +170,21 @@ struct GroupedConvFwdKernelArgs
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Create and STORE transformer (for split-image support)
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
a_grid_desc_m_k =
conv_to_gemm_transformer
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
b_grid_desc_n_k =
conv_to_gemm_transformer
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
c_grid_desc_m_n =
conv_to_gemm_transformer
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
@@ -197,15 +196,16 @@ struct GroupedConvFwdKernelArgs
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_per_split = transformer_.GetN();
original_n = transformer_.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NHWGC layout
// VERIFIED: G_ MUST be included - NHWGC layout has all groups within each batch
input_batch_stride =
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
output_batch_stride =
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
@@ -270,23 +270,21 @@ struct GroupedConvFwdKernelArgs
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Create and STORE transformer (for split-image support)
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
a_grid_desc_m_k =
conv_to_gemm_transformer
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
b_grid_desc_n_k =
conv_to_gemm_transformer
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
c_grid_desc_m_n =
conv_to_gemm_transformer
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
@@ -298,14 +296,15 @@ struct GroupedConvFwdKernelArgs
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_per_split = transformer_.GetN();
original_n = transformer_.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NDHWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
// VERIFIED: G_ MUST be included - NDHWGC layout has all groups within each batch
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
// Update GemmM to use split N (not original N)
@@ -359,6 +358,42 @@ struct GroupedConvFwdKernelArgs
index_t original_n = 1; // Original batch size before splitting
index_t input_batch_stride = 0; // Stride to next batch in input tensor
index_t output_batch_stride = 0; // Stride to next batch in output tensor
// Split-image support - spatial offsets (applied per-batch in operator())
long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split)
long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split)
// Split-image support - transformer instance
ConvToGemmFwdTransformer transformer_;
// Forward declare descriptor types (will be defined after using declarations)
using ConvToGemmFwdTransformer_t = ConvToGemmFwdTransformer;
using AGridDescMK_t = AGridDescMK;
using CGridDescMN_t = CGridDescMN;
// Split-image support: Common data for all pieces
struct SplitImageInfo
{
// Common dimensions (same for all pieces)
index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
index_t total_spatial = 1; // Pre-calculated: total_d * total_h * total_w
index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors
// Minimal per-piece data (only unique values)
struct PieceInfo
{
index_t block_start; // Starting block index for this piece
index_t block_end; // Ending block index (exclusive)
index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
index_t d_size, h_size, w_size; // Piece size in OUTPUT space
};
static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
};
index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
SplitImageInfo split_image; // Nested structure with common + per-piece data
};
/// @brief The Grouped Convolution Forward kernel template.
@@ -399,13 +434,15 @@ struct GroupedConvFwdKernelArgs
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output C tensor in global memory.
template <typename GroupedConvTraitsType_,
template <bool EnableSplitImage_,
typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -435,7 +472,6 @@ struct GroupedConvolutionForwardKernel
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
// TODO: Enable this
static constexpr bool IsSplitKSupported = false;
static constexpr auto I0 = number<0>();
@@ -449,6 +485,77 @@ struct GroupedConvolutionForwardKernel
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
// Helper struct for spatial coordinates
struct SpatialCoords
{
index_t d, h, w;
};
// Helper: Convert flat spatial index to (d,h,w) coordinates
CK_TILE_DEVICE static SpatialCoords
UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
{
if constexpr(NDimSpatial == 1)
{
return SpatialCoords{0, 0, flat};
}
else if constexpr(NDimSpatial == 2)
{
return SpatialCoords{0, flat / w_size, flat % w_size};
}
else // NDimSpatial == 3
{
const index_t hw = h_size * w_size;
const index_t d = flat / hw;
const index_t remainder = flat % hw;
return SpatialCoords{d, remainder / w_size, remainder % w_size};
}
}
// Helper: Convert (d,h,w) to flat spatial index
CK_TILE_DEVICE static index_t
FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
{
if constexpr(NDimSpatial == 1)
{
return w;
}
else if constexpr(NDimSpatial == 2)
{
return h * total_w + w;
}
else // NDimSpatial == 3
{
return (d * total_h + h) * total_w + w;
}
}
// Helper: Find which piece owns a block using binary search
template <typename SplitImageInfo>
CK_TILE_DEVICE static index_t
FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
{
index_t left = 0;
index_t right = num_pieces - 1;
index_t piece_id = (left + right) / 2;
while(!(block_id >= split_info.pieces[piece_id].block_start &&
block_id < split_info.pieces[piece_id].block_end) &&
left <= right)
{
if(block_id < split_info.pieces[piece_id].block_start)
{
right = piece_id - 1;
}
else
{
left = piece_id + 1;
}
piece_id = (left + right) / 2;
}
return piece_id;
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -475,7 +582,8 @@ struct GroupedConvolutionForwardKernel
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
{
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
return kargs;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -499,17 +607,6 @@ struct GroupedConvolutionForwardKernel
}
}
// Check Split-K and Split-N conflict (both use blockIdx.z)
if(kargs.k_batch > 1 && kargs.n_splits > 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
}
return false;
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
@@ -618,27 +715,32 @@ struct GroupedConvolutionForwardKernel
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename ADescType,
typename BDescType,
typename CDescType>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
const GroupedConvFwdKernelArgsSpecialized& kargs)
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
}();
const auto& b_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
}();
const auto& ds_tensor_view = generate_tuple(
@@ -651,7 +753,7 @@ struct GroupedConvolutionForwardKernel
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
},
number<NumDTensor>{});
@@ -743,31 +845,39 @@ struct GroupedConvolutionForwardKernel
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param ds_ptr input D tensors pointer array
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs Grouped Convolution Forward kernel arguments
* @param a_desc Input tensor A descriptor
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <typename ADescType, typename BDescType, typename CDescType>
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* smem_ptr_0,
const GroupedConvFwdKernelArgsSpecialized& kargs,
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -780,9 +890,8 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{kargs.elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
@@ -792,32 +901,40 @@ struct GroupedConvolutionForwardKernel
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param ds_ptr input D tensors pointer array
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs Grouped Convolution Forward kernel arguments
* @param a_desc Input tensor A descriptor
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <typename ADescType, typename BDescType, typename CDescType>
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GroupedConvFwdKernelArgsSpecialized& kargs,
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -837,12 +954,8 @@ struct GroupedConvolutionForwardKernel
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
@@ -860,14 +973,89 @@ struct GroupedConvolutionForwardKernel
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Adjust pointers: combine group offset and batch offset
const InDataType* a_ptr =
// Calculate base pointers with group and batch offsets
const InDataType* base_a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* c_ptr =
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
const InDataType* a_ptr;
OutDataType* c_ptr;
index_t i_m = 0;
index_t i_n = 0;
// Pre-calculate block_id (used in both split-image and non-split paths)
const index_t block_id = static_cast<index_t>(blockIdX);
if constexpr(EnableSplitImage)
{
// Add spatial offsets for split-image (constexpr optimization)
a_ptr = base_a_ptr + kargs.spatial_offset_in;
c_ptr = base_c_ptr + kargs.spatial_offset_out;
// Find which piece owns this block using binary search
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
const index_t piece_id =
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
const auto& piece = kargs.split_image.pieces[piece_id];
const auto& split_info = kargs.split_image;
// Calculate local block ID and tile indices
const index_t local_block_id = block_id - piece.block_start;
const index_t local_gemm_m =
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
const auto [local_tile_m, local_tile_n] =
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
// Extract batch and spatial coordinates from local tile
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
const index_t local_n = local_m_start / spatial_per_batch;
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
// Convert to local spatial coordinates
const auto local_coords =
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
// Convert to global spatial coordinates
const index_t global_n = local_n;
const index_t global_d = piece.d_start + local_coords.d;
const index_t global_h = piece.h_start + local_coords.h;
const index_t global_w = piece.w_start + local_coords.w;
// Convert to global M index
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
const index_t global_spatial_flat = FlattenSpatial(
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
// Set tile indices for GEMM operation
i_m = amd_wave_read_first_lane(global_m);
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
}
else
{
// No spatial offsets needed for regular path
a_ptr = base_a_ptr;
c_ptr = base_c_ptr;
// No split-image: use standard tile partitioning
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
}
// Use global descriptors for all cases
const auto& a_desc = kargs.a_grid_desc_m_k;
const auto& b_desc = kargs.b_grid_desc_n_k;
const auto& c_desc = kargs.c_grid_desc_m_n;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
@@ -878,8 +1066,18 @@ struct GroupedConvolutionForwardKernel
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n);
}
}
else
@@ -888,7 +1086,17 @@ struct GroupedConvolutionForwardKernel
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n);
}
}
}