[CK TILE] Grouped Conv Explicit Gemm (#3289)

* [CK TILE] Grouped Conv Explicit Gemm

* fixes

* apply builder fixes
Author: Bartłomiej Kocot (committed by GitHub)
Date: 2025-11-25 23:28:35 +01:00
Parent: 37ea160088
Commit: 00dfa2f2ce
13 changed files with 386 additions and 269 deletions
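
In short, when a grouped convolution uses the Filter1x1Stride1Pad0 specialization (and NumGroupsToMerge == 1), the forward and backward-weight kernels can now bypass the implicit-GEMM path and hand the whole problem to BatchedGemmKernel, one GEMM per convolution group; backward-data still rejects the new ExplicitGemm trait. A minimal reference of the forward case, illustrative only and not part of the commit, with generic names and plain NHWGC / GKC / NHWGK indexing assumed:

// Illustrative sketch: why a 1x1, stride-1, pad-0 grouped convolution is a
// batched GEMM. M = N*Ho*Wo; the GemmM/GemmN/GemmK labels show the assumed
// correspondence to the kernel's GEMM dimensions.
void reference_grouped_conv_1x1(const float* in,  // [M, G, C]
                                const float* wei, // [G, K, C]
                                float* out,       // [M, G, K]
                                int M, int G, int K, int C)
{
    for(int g = 0; g < G; ++g)             // one GEMM per group -> batch_count
        for(int m = 0; m < M; ++m)         // GemmM
            for(int k = 0; k < K; ++k)     // GemmN
            {
                float acc = 0.f;
                for(int c = 0; c < C; ++c) // GemmK
                    acc += in[(m * G + g) * C + c] * wei[(g * K + k) * C + c];
                out[(m * G + g) * K + k] = acc;
            }
}

With a 1x1 filter, unit stride and no padding, the implicit im2col transform is the identity, so the input, weight and output tensors are already valid GEMM operands.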


@@ -511,7 +511,7 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardDataKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -556,6 +556,7 @@ struct GroupedConvolutionBackwardDataKernel
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported C GEMM layout!");
static_assert(GroupedConvTraitsType_::ExplicitGemm == false, "Not supported yet!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -983,7 +984,7 @@ struct GroupedConvolutionBackwardDataKernel
return group_id;
}
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized& kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const index_t group_id = FindGroupId(kargs, blockIdX);


@@ -370,7 +370,7 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -411,6 +411,9 @@ struct GroupedConvolutionBackwardWeightKernel
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Not supported!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -503,22 +506,6 @@ struct GroupedConvolutionBackwardWeightKernel
index_t splitted_k;
};
CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const stream_config& s)
{
return [&]() {
if(kargs.k_batch > 1)
{
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
// since we require that ConvG % NumGroupsPerBatch == 0.
const auto wei_size =
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
hipGetErrorString(
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
}
};
}
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
@@ -588,6 +575,14 @@ struct GroupedConvolutionBackwardWeightKernel
}
}
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
CK_TILE_ERROR(
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
return false;
}
namespace ctc = tensor_layout::convolution;
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
@@ -886,61 +881,104 @@ struct GroupedConvolutionBackwardWeightKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
static_assert(NumDTensor == 0, "Not supported!");
using ExplicitBatchedGemmKernel =
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
{{kargs.out_ptr},
{kargs.in_ptr},
{},
kargs.wei_ptr,
kargs.GemmM,
kargs.GemmN,
kargs.GemmK,
{kargs.GemmM * kargs.GemmBatch},
{kargs.GemmN * kargs.GemmBatch},
{},
kargs.GemmN,
kargs.k_batch},
kargs.GemmM,
kargs.GemmN,
kargs.GemmM * kargs.GemmN,
kargs.GemmBatch};
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
const index_t num_loop = amd_wave_read_first_lane(
ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
const index_t i_k =
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// options
// conv_bwd_weight = Out * In = Weight
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
CallExplicitGemm(kargs);
}
else
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
const index_t num_loop = amd_wave_read_first_lane(ck_tile::integer_divide_ceil(
kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
const index_t i_k =
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// options
// conv_bwd_weight = Out * In = Weight
const OutDataType* a_ptr =
static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
RunGemm2LDS(a_ptr,
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
num_loop,
i_m,
i_n,
i_k);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
num_loop,
i_m,
i_n,
i_k);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
}
}
}
}

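The backward-weight path above performs the matching per-group contraction ("conv_bwd_weight = Out * In = Weight"), with the reduction running over the flattened batch-and-spatial dimension that k_batch splits. A minimal sketch, again illustrative only and with generic names:

// Illustrative sketch of the explicit backward-weight GEMM per group.
// M = N*Ho*Wo is the reduction (GemmK) dimension.
void reference_grouped_conv_1x1_bwd_weight(const float* out_grad, // [M, G, K]
                                           const float* in,       // [M, G, C]
                                           float* wei_grad,       // [G, K, C]
                                           int M, int G, int K, int C)
{
    for(int g = 0; g < G; ++g)             // batch_count = GemmBatch
        for(int k = 0; k < K; ++k)         // GemmM
            for(int c = 0; c < C; ++c)     // GemmN
            {
                float acc = 0.f;
                for(int m = 0; m < M; ++m) // GemmK, split across k_batch
                    acc += out_grad[(m * G + g) * K + k] * in[(m * G + g) * C + c];
                wei_grad[(g * K + k) * C + c] = acc;
            }
}

As the new static_assert and IsSupportedArgument check enforce, this path is only taken for Filter1x1Stride1Pad0 with NumGroupsToMerge == 1.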

@@ -490,6 +490,9 @@ struct GroupedConvolutionForwardKernel
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Not supported!");
// Helper struct for spatial coordinates
struct SpatialCoords
@@ -678,6 +681,14 @@ struct GroupedConvolutionForwardKernel
}
}
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
CK_TILE_ERROR(
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
return false;
}
namespace ctc = tensor_layout::convolution;
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
@@ -974,135 +985,189 @@ struct GroupedConvolutionForwardKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
static_assert(NumDTensor == 0, "Not supported!");
using ExplicitBatchedGemmKernel =
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
{{kargs.in_ptr},
{kargs.wei_ptr},
{},
kargs.out_ptr,
kargs.GemmM,
kargs.GemmN,
kargs.GemmK,
{kargs.GemmK * kargs.GemmBatch},
{kargs.GemmK},
{},
kargs.GemmBatch * kargs.GemmN,
kargs.k_batch},
kargs.GemmK,
kargs.GemmN * kargs.GemmK,
kargs.GemmN,
kargs.GemmBatch};
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// Split-N handling: Get which split this workgroup handles
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
// Calculate batch offset for this split
const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
// Calculate memory offsets for this split
const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
const long_index_t output_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Calculate base pointers with group and batch offsets
const InDataType* base_a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// Apply group offsets to D tensors
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
static_for<0, NumDTensor, 1>{}([&](auto d) {
using DType = std::tuple_element_t<d, DsDataType>;
ds_ptr_with_offsets[d] =
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
});
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
const InDataType* a_ptr;
OutDataType* c_ptr;
index_t i_m = 0;
index_t i_n = 0;
// Pre-calculate block_id (used in both split-image and non-split paths)
const index_t block_id = static_cast<index_t>(blockIdX);
if constexpr(EnableSplitImage)
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
// Add spatial offsets for split-image (constexpr optimization)
a_ptr = base_a_ptr + kargs.spatial_offset_in;
c_ptr = base_c_ptr + kargs.spatial_offset_out;
// Find which piece owns this block using binary search
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
const index_t piece_id =
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
const auto& piece = kargs.split_image.pieces[piece_id];
const auto& split_info = kargs.split_image;
// Calculate local block ID and tile indices
const index_t local_block_id = block_id - piece.block_start;
const index_t local_gemm_m =
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
const auto [local_tile_m, local_tile_n] =
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
// Extract batch and spatial coordinates from local tile
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
const index_t local_n = local_m_start / spatial_per_batch;
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
// Convert to local spatial coordinates
const auto local_coords =
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
// Convert to global spatial coordinates
const index_t global_n = local_n;
const index_t global_d = piece.d_start + local_coords.d;
const index_t global_h = piece.h_start + local_coords.h;
const index_t global_w = piece.w_start + local_coords.w;
// Convert to global M index
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
const index_t global_spatial_flat = FlattenSpatial(
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
// Set tile indices for GEMM operation
i_m = amd_wave_read_first_lane(global_m);
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
CallExplicitGemm(kargs);
}
else
{
// No spatial offsets needed for regular path
a_ptr = base_a_ptr;
c_ptr = base_c_ptr;
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
// No split-image: use standard tile partitioning
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
}
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// Use global descriptors for all cases
const auto& a_desc = kargs.a_grid_desc_m_k;
const auto& b_desc = kargs.b_grid_desc_n_k;
const auto& c_desc = kargs.c_grid_desc_m_n;
// Split-N handling: Get which split this workgroup handles
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// Calculate batch offset for this split
const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
// Calculate memory offsets for this split
const long_index_t input_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
const long_index_t output_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Calculate base pointers with group and batch offsets
const InDataType* base_a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// Apply group offsets to D tensors
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
static_for<0, NumDTensor, 1>{}([&](auto d) {
using DType = std::tuple_element_t<d, DsDataType>;
ds_ptr_with_offsets[d] = static_cast<const DType*>(kargs.ds_ptr[d]) +
group_offset_c + output_batch_offset;
});
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
const InDataType* a_ptr;
OutDataType* c_ptr;
index_t i_m = 0;
index_t i_n = 0;
// Pre-calculate block_id (used in both split-image and non-split paths)
const index_t block_id = static_cast<index_t>(blockIdX);
if constexpr(EnableSplitImage)
{
RunGemm2LDS(a_ptr,
// Add spatial offsets for split-image (constexpr optimization)
a_ptr = base_a_ptr + kargs.spatial_offset_in;
c_ptr = base_c_ptr + kargs.spatial_offset_out;
// Find which piece owns this block using binary search
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
const index_t piece_id =
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
const auto& piece = kargs.split_image.pieces[piece_id];
const auto& split_info = kargs.split_image;
// Calculate local block ID and tile indices
const index_t local_block_id = block_id - piece.block_start;
const index_t local_gemm_m =
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
const auto [local_tile_m, local_tile_n] =
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
// Extract batch and spatial coordinates from local tile
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
const index_t local_n = local_m_start / spatial_per_batch;
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
// Convert to local spatial coordinates
const auto local_coords =
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
// Convert to global spatial coordinates
const index_t global_n = local_n;
const index_t global_d = piece.d_start + local_coords.d;
const index_t global_h = piece.h_start + local_coords.h;
const index_t global_w = piece.w_start + local_coords.w;
// Convert to global M index
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
const index_t global_spatial_flat = FlattenSpatial(
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
// Set tile indices for GEMM operation
i_m = amd_wave_read_first_lane(global_m);
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
}
else
{
// No spatial offsets needed for regular path
a_ptr = base_a_ptr;
c_ptr = base_c_ptr;
// No split-image: use standard tile partitioning
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
}
// Use global descriptors for all cases
const auto& a_desc = kargs.a_grid_desc_m_k;
const auto& b_desc = kargs.b_grid_desc_n_k;
const auto& c_desc = kargs.c_grid_desc_m_n;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n,
kargs.elfunc);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
@@ -1110,26 +1175,7 @@ struct GroupedConvolutionForwardKernel
i_m,
i_n,
kargs.elfunc);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n,
kargs.elfunc);
}
}
}
}

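For orientation, a rough reading of the batched-GEMM arguments built in the forward CallExplicitGemm above, assuming the argument order is {A, B, Ds, C, M, N, K, stride_A, stride_B, stride_Ds, stride_C, k_batch} followed by {batch_stride_A, batch_stride_B, batch_stride_C, batch_count} (an assumption about the BatchedGemmKernelArgs layout, not stated in this diff):

  A (in_ptr)  : GemmM x GemmK, leading stride GemmK * GemmBatch, per-group stride GemmK
  B (wei_ptr) : GemmN x GemmK, leading stride GemmK,             per-group stride GemmN * GemmK
  C (out_ptr) : GemmM x GemmN, leading stride GemmBatch * GemmN, per-group stride GemmN
  batch_count : GemmBatch

That is, the group dimension of the NHWGC / GKC / NHWGK tensors is folded into the leading and per-group strides rather than into grid descriptors, which is why the path is restricted to the Filter1x1Stride1Pad0 specialization.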

@@ -63,7 +63,8 @@ template <index_t NDimSpatial_,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
bool EnableSplitImage_ = false>
bool EnableSplitImage_ = false,
bool ExplicitGemm_ = false>
struct GroupedConvTraits
{
private:
@@ -89,8 +90,9 @@ struct GroupedConvTraits
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
};
// Compile time parameters
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr bool ExplicitGemm = ExplicitGemm_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;
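// Hypothetical usage sketch (earlier template arguments elided as "...",
// alias name illustrative; only the three trailing parameters and their
// defaults appear in this diff):
//
//   using ExplicitGemmTraits =
//       GroupedConvTraits<...,
//                         /* NumGroupsToMerge_ = */ 1,
//                         /* EnableSplitImage_ = */ false,
//                         /* ExplicitGemm_     = */ true>;
//
// ExplicitGemm_ defaults to false, so existing instantiations keep the
// implicit-GEMM behavior, and NumGroupsToMerge_ must stay 1 whenever the
// flag is enabled.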