[CK TILE] Grouped Conv Explicit Gemm (#3289)

* [CK TILE] Grouped Conv Explicit Gemm

* fixes

* apply builder fixes
This commit is contained in:
Bartłomiej Kocot
2025-11-25 23:28:35 +01:00
committed by GitHub
parent 37ea160088
commit 00dfa2f2ce
13 changed files with 386 additions and 269 deletions

View File

@@ -28,10 +28,6 @@ struct GroupedConvolutionForwardInvoker
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,

View File

@@ -20,11 +20,6 @@ struct GroupedConvolutionForwardInvoker
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDEElementWise>& args,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,

View File

@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// Explicit GEMM flag, mirrored from the conv traits.
// Kept as bool to match GroupedConvTraits::ExplicitGemm and kEnableSplitImage;
// it still streams as 0/1 in the instance string.
static constexpr bool kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
// TilePartitioner
// Block configuration
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << "," << kExplicitGemm; // 12. ExplicitGemm
oss << "," << kMPerBlock; // 13. MPerBlock
oss << "," << kNPerBlock; // 14. NPerBlock
oss << "," << kKPerBlock; // 15. KPerBlock
oss << "," << kMWarp; // 16. MWarp
oss << "," << kNWarp; // 17. NWarp
oss << "," << kKWarp; // 18. KWarp
oss << "," << kMWarpTile; // 19. MWarpTile
oss << "," << kNWarpTile; // 20. NWarpTile
oss << "," << kKWarpTile; // 21. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << ">";

View File

@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// Explicit GEMM flag, mirrored from the conv traits.
// Kept as bool to match GroupedConvTraits::ExplicitGemm and kEnableSplitImage;
// it still streams as 0/1 in the instance string.
static constexpr bool kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
// TilePartitioner
// Block configuration
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << "," << kExplicitGemm; // 12. ExplicitGemm
oss << "," << kMPerBlock; // 13. MPerBlock
oss << "," << kNPerBlock; // 14. NPerBlock
oss << "," << kKPerBlock; // 15. KPerBlock
oss << "," << kMWarp; // 16. MWarp
oss << "," << kNWarp; // 17. NWarp
oss << "," << kKWarp; // 18. KWarp
oss << "," << kMWarpTile; // 19. MWarpTile
oss << "," << kNWarpTile; // 20. NWarpTile
oss << "," << kKWarpTile; // 21. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << ">";

View File

@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// Explicit GEMM flag, mirrored from the conv traits.
// Kept as bool to match GroupedConvTraits::ExplicitGemm and kEnableSplitImage;
// it still streams as 0/1 in the instance string.
static constexpr bool kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
// TilePartitioner
// Block configuration
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << "," << kExplicitGemm; // 12. ExplicitGemm
oss << "," << kMPerBlock; // 13. MPerBlock
oss << "," << kNPerBlock; // 14. NPerBlock
oss << "," << kKPerBlock; // 15. KPerBlock
oss << "," << kMWarp; // 16. MWarp
oss << "," << kNWarp; // 17. NWarp
oss << "," << kKWarp; // 18. KWarp
oss << "," << kMWarpTile; // 19. MWarpTile
oss << "," << kNWarpTile; // 20. NWarpTile
oss << "," << kKWarpTile; // 21. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << ">";

View File

@@ -21,7 +21,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
@@ -106,6 +107,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock

View File

@@ -123,7 +123,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
@@ -208,6 +209,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock

View File

@@ -734,7 +734,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
@@ -818,6 +819,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock

View File

@@ -228,10 +228,34 @@ struct BatchedGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr) + batch_offset_C;
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
__shared__ char smem_ptr0[GetSmemSize()];
UniversalGemmKernel::RunGemm(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr1[GetSmemSize()];
UniversalGemmKernel::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr0,
smem_ptr1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
UniversalGemmKernel::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
};

View File

@@ -511,7 +511,7 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardDataKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -556,6 +556,7 @@ struct GroupedConvolutionBackwardDataKernel
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported C GEMM layout!");
// The backward-data kernel rejects the ExplicitGemm trait at compile time.
static_assert(GroupedConvTraitsType_::ExplicitGemm == false,
              "Explicit GEMM is not supported yet for grouped convolution backward data!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -983,7 +984,7 @@ struct GroupedConvolutionBackwardDataKernel
return group_id;
}
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized& kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const index_t group_id = FindGroupId(kargs, blockIdX);

View File

@@ -370,7 +370,7 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -411,6 +411,9 @@ struct GroupedConvolutionBackwardWeightKernel
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
// Explicit GEMM is only accepted when no groups are merged (NumGroupsToMerge == 1).
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
                  GroupedConvTraitsType_::NumGroupsToMerge == 1,
              "Explicit GEMM requires NumGroupsToMerge == 1!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -503,22 +506,6 @@ struct GroupedConvolutionBackwardWeightKernel
index_t splitted_k;
};
CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const stream_config& s)
{
return [&]() {
if(kargs.k_batch > 1)
{
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
// since we require that ConvG % NumGroupsPerBatch == 0.
const auto wei_size =
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
hipGetErrorString(
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
}
};
}
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
@@ -588,6 +575,14 @@ struct GroupedConvolutionBackwardWeightKernel
}
}
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
CK_TILE_ERROR(
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
return false;
}
namespace ctc = tensor_layout::convolution;
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
@@ -886,61 +881,104 @@ struct GroupedConvolutionBackwardWeightKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
static_assert(NumDTensor == 0, "Not supported!");
using ExplicitBatchedGemmKernel =
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
{{kargs.out_ptr},
{kargs.in_ptr},
{},
kargs.wei_ptr,
kargs.GemmM,
kargs.GemmN,
kargs.GemmK,
{kargs.GemmM * kargs.GemmBatch},
{kargs.GemmN * kargs.GemmBatch},
{},
kargs.GemmN,
kargs.k_batch},
kargs.GemmM,
kargs.GemmN,
kargs.GemmM * kargs.GemmN,
kargs.GemmBatch};
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
const index_t num_loop = amd_wave_read_first_lane(
ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
const index_t i_k =
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// options
// conv_bwd_weight = Out * In = Weight
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
CallExplicitGemm(kargs);
}
else
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
const index_t num_loop = amd_wave_read_first_lane(ck_tile::integer_divide_ceil(
kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
const index_t i_k =
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// options
// conv_bwd_weight = Out * In = Weight
const OutDataType* a_ptr =
static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
RunGemm2LDS(a_ptr,
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
num_loop,
i_m,
i_n,
i_k);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
num_loop,
i_m,
i_n,
i_k);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
}
}
}
}

View File

@@ -490,6 +490,9 @@ struct GroupedConvolutionForwardKernel
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
// Explicit GEMM is only accepted when no groups are merged (NumGroupsToMerge == 1).
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
                  GroupedConvTraitsType_::NumGroupsToMerge == 1,
              "Explicit GEMM requires NumGroupsToMerge == 1!");
// Helper struct for spatial coordinates
struct SpatialCoords
@@ -678,6 +681,14 @@ struct GroupedConvolutionForwardKernel
}
}
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
CK_TILE_ERROR(
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
return false;
}
namespace ctc = tensor_layout::convolution;
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
@@ -974,135 +985,189 @@ struct GroupedConvolutionForwardKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
static_assert(NumDTensor == 0, "Not supported!");
using ExplicitBatchedGemmKernel =
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
{{kargs.in_ptr},
{kargs.wei_ptr},
{},
kargs.out_ptr,
kargs.GemmM,
kargs.GemmN,
kargs.GemmK,
{kargs.GemmK * kargs.GemmBatch},
{kargs.GemmK},
{},
kargs.GemmBatch * kargs.GemmN,
kargs.k_batch},
kargs.GemmK,
kargs.GemmN * kargs.GemmK,
kargs.GemmN,
kargs.GemmBatch};
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// Split-N handling: Get which split this workgroup handles
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
// Calculate batch offset for this split
const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
// Calculate memory offsets for this split
const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
const long_index_t output_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Calculate base pointers with group and batch offsets
const InDataType* base_a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// Apply group offsets to D tensors
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
static_for<0, NumDTensor, 1>{}([&](auto d) {
using DType = std::tuple_element_t<d, DsDataType>;
ds_ptr_with_offsets[d] =
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
});
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
const InDataType* a_ptr;
OutDataType* c_ptr;
index_t i_m = 0;
index_t i_n = 0;
// Pre-calculate block_id (used in both split-image and non-split paths)
const index_t block_id = static_cast<index_t>(blockIdX);
if constexpr(EnableSplitImage)
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
// Add spatial offsets for split-image (constexpr optimization)
a_ptr = base_a_ptr + kargs.spatial_offset_in;
c_ptr = base_c_ptr + kargs.spatial_offset_out;
// Find which piece owns this block using binary search
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
const index_t piece_id =
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
const auto& piece = kargs.split_image.pieces[piece_id];
const auto& split_info = kargs.split_image;
// Calculate local block ID and tile indices
const index_t local_block_id = block_id - piece.block_start;
const index_t local_gemm_m =
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
const auto [local_tile_m, local_tile_n] =
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
// Extract batch and spatial coordinates from local tile
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
const index_t local_n = local_m_start / spatial_per_batch;
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
// Convert to local spatial coordinates
const auto local_coords =
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
// Convert to global spatial coordinates
const index_t global_n = local_n;
const index_t global_d = piece.d_start + local_coords.d;
const index_t global_h = piece.h_start + local_coords.h;
const index_t global_w = piece.w_start + local_coords.w;
// Convert to global M index
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
const index_t global_spatial_flat = FlattenSpatial(
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
// Set tile indices for GEMM operation
i_m = amd_wave_read_first_lane(global_m);
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
CallExplicitGemm(kargs);
}
else
{
// No spatial offsets needed for regular path
a_ptr = base_a_ptr;
c_ptr = base_c_ptr;
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
// No split-image: use standard tile partitioning
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
}
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// Use global descriptors for all cases
const auto& a_desc = kargs.a_grid_desc_m_k;
const auto& b_desc = kargs.b_grid_desc_n_k;
const auto& c_desc = kargs.c_grid_desc_m_n;
// Split-N handling: Get which split this workgroup handles
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// Calculate batch offset for this split
const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
// Calculate memory offsets for this split
const long_index_t input_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
const long_index_t output_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Calculate base pointers with group and batch offsets
const InDataType* base_a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// Apply group offsets to D tensors
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
static_for<0, NumDTensor, 1>{}([&](auto d) {
using DType = std::tuple_element_t<d, DsDataType>;
ds_ptr_with_offsets[d] = static_cast<const DType*>(kargs.ds_ptr[d]) +
group_offset_c + output_batch_offset;
});
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
const InDataType* a_ptr;
OutDataType* c_ptr;
index_t i_m = 0;
index_t i_n = 0;
// Pre-calculate block_id (used in both split-image and non-split paths)
const index_t block_id = static_cast<index_t>(blockIdX);
if constexpr(EnableSplitImage)
{
RunGemm2LDS(a_ptr,
// Add spatial offsets for split-image (constexpr optimization)
a_ptr = base_a_ptr + kargs.spatial_offset_in;
c_ptr = base_c_ptr + kargs.spatial_offset_out;
// Find which piece owns this block using binary search
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
const index_t piece_id =
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
const auto& piece = kargs.split_image.pieces[piece_id];
const auto& split_info = kargs.split_image;
// Calculate local block ID and tile indices
const index_t local_block_id = block_id - piece.block_start;
const index_t local_gemm_m =
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
const auto [local_tile_m, local_tile_n] =
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
// Extract batch and spatial coordinates from local tile
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
const index_t local_n = local_m_start / spatial_per_batch;
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
// Convert to local spatial coordinates
const auto local_coords =
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
// Convert to global spatial coordinates
const index_t global_n = local_n;
const index_t global_d = piece.d_start + local_coords.d;
const index_t global_h = piece.h_start + local_coords.h;
const index_t global_w = piece.w_start + local_coords.w;
// Convert to global M index
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
const index_t global_spatial_flat = FlattenSpatial(
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
// Set tile indices for GEMM operation
i_m = amd_wave_read_first_lane(global_m);
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
}
else
{
// No spatial offsets needed for regular path
a_ptr = base_a_ptr;
c_ptr = base_c_ptr;
// No split-image: use standard tile partitioning
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
}
// Use global descriptors for all cases
const auto& a_desc = kargs.a_grid_desc_m_k;
const auto& b_desc = kargs.b_grid_desc_n_k;
const auto& c_desc = kargs.c_grid_desc_m_n;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n,
kargs.elfunc);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
@@ -1110,26 +1175,7 @@ struct GroupedConvolutionForwardKernel
i_m,
i_n,
kargs.elfunc);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n,
kargs.elfunc);
}
}
}
}

View File

@@ -63,7 +63,8 @@ template <index_t NDimSpatial_,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
bool EnableSplitImage_ = false>
bool EnableSplitImage_ = false,
bool ExplicitGemm_ = false>
struct GroupedConvTraits
{
private:
@@ -89,8 +90,9 @@ struct GroupedConvTraits
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
};
// Compile time parameters
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr bool ExplicitGemm = ExplicitGemm_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;