mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK TILE] Grouped Conv Explicit Gemm (#3289)
* [CK TILE] Grouped Conv Explicit Gemm * fixes * apply builder fixes
This commit is contained in:
@@ -28,10 +28,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
|
||||
@@ -20,11 +20,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDEElementWise>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
4 /*VectorSizeB*/,
|
||||
4 /*VectorSizeC*/,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
false /*EnableSplitImage*/>;
|
||||
false /*EnableSplitImage*/,
|
||||
false /*ExplicitGemm*/>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
|
||||
@@ -106,6 +107,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
",4" // VectorSizeC
|
||||
",1" // NumGroupsToMerge
|
||||
",0" // EnableSplitImage
|
||||
",0" // ExplicitGemm
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
|
||||
@@ -123,7 +123,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
4 /*VectorSizeB*/,
|
||||
4 /*VectorSizeC*/,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
false /*EnableSplitImage*/>;
|
||||
false /*EnableSplitImage*/,
|
||||
false /*ExplicitGemm*/>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
|
||||
@@ -208,6 +209,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
",4" // VectorSizeC
|
||||
",1" // NumGroupsToMerge
|
||||
",0" // EnableSplitImage
|
||||
",0" // ExplicitGemm
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
|
||||
@@ -734,7 +734,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
4 /*VectorSizeB*/,
|
||||
4 /*VectorSizeC*/,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
false /*EnableSplitImage*/>;
|
||||
false /*EnableSplitImage*/,
|
||||
false /*ExplicitGemm*/>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
|
||||
@@ -818,6 +819,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
",4" // VectorSizeC
|
||||
",1" // NumGroupsToMerge
|
||||
",0" // EnableSplitImage
|
||||
",0" // ExplicitGemm
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
|
||||
@@ -228,10 +228,34 @@ struct BatchedGemmKernel
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr) + batch_offset_C;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
__shared__ char smem_ptr0[GetSmemSize()];
|
||||
|
||||
UniversalGemmKernel::RunGemm(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr1[GetSmemSize()];
|
||||
UniversalGemmKernel::RunGemm2LDS({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr,
|
||||
smem_ptr0,
|
||||
smem_ptr1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
UniversalGemmKernel::RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr,
|
||||
smem_ptr0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -511,7 +511,7 @@ template <typename GroupedConvTraitsType_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardDataKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -556,6 +556,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported C GEMM layout!");
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false, "Not supported yet!");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -983,7 +984,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
return group_id;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const index_t group_id = FindGroupId(kargs, blockIdX);
|
||||
|
||||
@@ -370,7 +370,7 @@ template <typename GroupedConvTraitsType_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -411,6 +411,9 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
|
||||
"Not supported!");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -503,22 +506,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const stream_config& s)
|
||||
{
|
||||
return [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
|
||||
// since we require that ConvG % NumGroupsPerBatch == 0.
|
||||
const auto wei_size =
|
||||
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
@@ -588,6 +575,14 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
|
||||
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
@@ -886,61 +881,104 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
static_assert(NumDTensor == 0, "Not supported!");
|
||||
using ExplicitBatchedGemmKernel =
|
||||
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
|
||||
{{kargs.out_ptr},
|
||||
{kargs.in_ptr},
|
||||
{},
|
||||
kargs.wei_ptr,
|
||||
kargs.GemmM,
|
||||
kargs.GemmN,
|
||||
kargs.GemmK,
|
||||
{kargs.GemmM * kargs.GemmBatch},
|
||||
{kargs.GemmN * kargs.GemmBatch},
|
||||
{},
|
||||
kargs.GemmN,
|
||||
kargs.k_batch},
|
||||
kargs.GemmM,
|
||||
kargs.GemmN,
|
||||
kargs.GemmM * kargs.GemmN,
|
||||
kargs.GemmBatch};
|
||||
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
|
||||
}
|
||||
|
||||
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
|
||||
const index_t num_loop = amd_wave_read_first_lane(
|
||||
ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
|
||||
const index_t i_k =
|
||||
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
|
||||
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// options
|
||||
// conv_bwd_weight = Out * In = Weight
|
||||
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
CallExplicitGemm(kargs);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
|
||||
const index_t num_loop = amd_wave_read_first_lane(ck_tile::integer_divide_ceil(
|
||||
kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
|
||||
const index_t i_k =
|
||||
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
|
||||
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// options
|
||||
// conv_bwd_weight = Out * In = Weight
|
||||
const OutDataType* a_ptr =
|
||||
static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
num_loop,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
num_loop,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -490,6 +490,9 @@ struct GroupedConvolutionForwardKernel
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
|
||||
"Not supported!");
|
||||
|
||||
// Helper struct for spatial coordinates
|
||||
struct SpatialCoords
|
||||
@@ -678,6 +681,14 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
|
||||
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
@@ -974,135 +985,189 @@ struct GroupedConvolutionForwardKernel
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
static_assert(NumDTensor == 0, "Not supported!");
|
||||
using ExplicitBatchedGemmKernel =
|
||||
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
|
||||
{{kargs.in_ptr},
|
||||
{kargs.wei_ptr},
|
||||
{},
|
||||
kargs.out_ptr,
|
||||
kargs.GemmM,
|
||||
kargs.GemmN,
|
||||
kargs.GemmK,
|
||||
{kargs.GemmK * kargs.GemmBatch},
|
||||
{kargs.GemmK},
|
||||
{},
|
||||
kargs.GemmBatch * kargs.GemmN,
|
||||
kargs.k_batch},
|
||||
kargs.GemmK,
|
||||
kargs.GemmN * kargs.GemmK,
|
||||
kargs.GemmN,
|
||||
kargs.GemmBatch};
|
||||
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
|
||||
}
|
||||
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// Split-N handling: Get which split this workgroup handles
|
||||
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
|
||||
|
||||
// Calculate batch offset for this split
|
||||
const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
|
||||
|
||||
// Calculate memory offsets for this split
|
||||
const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.input_batch_stride);
|
||||
const long_index_t output_batch_offset =
|
||||
static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.output_batch_stride);
|
||||
|
||||
// Calculate base pointers with group and batch offsets
|
||||
const InDataType* base_a_ptr =
|
||||
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
|
||||
group_offset_b; // No batch offset for weights!
|
||||
OutDataType* base_c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// Apply group offsets to D tensors
|
||||
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto d) {
|
||||
using DType = std::tuple_element_t<d, DsDataType>;
|
||||
ds_ptr_with_offsets[d] =
|
||||
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
|
||||
});
|
||||
|
||||
// =====================================================================
|
||||
// Split-image: Map local block to global tile index (if enabled)
|
||||
// =====================================================================
|
||||
const InDataType* a_ptr;
|
||||
OutDataType* c_ptr;
|
||||
index_t i_m = 0;
|
||||
index_t i_n = 0;
|
||||
|
||||
// Pre-calculate block_id (used in both split-image and non-split paths)
|
||||
const index_t block_id = static_cast<index_t>(blockIdX);
|
||||
|
||||
if constexpr(EnableSplitImage)
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
{
|
||||
// Add spatial offsets for split-image (constexpr optimization)
|
||||
a_ptr = base_a_ptr + kargs.spatial_offset_in;
|
||||
c_ptr = base_c_ptr + kargs.spatial_offset_out;
|
||||
|
||||
// Find which piece owns this block using binary search
|
||||
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
|
||||
const index_t piece_id =
|
||||
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
|
||||
const auto& piece = kargs.split_image.pieces[piece_id];
|
||||
const auto& split_info = kargs.split_image;
|
||||
|
||||
// Calculate local block ID and tile indices
|
||||
const index_t local_block_id = block_id - piece.block_start;
|
||||
const index_t local_gemm_m =
|
||||
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
|
||||
const auto [local_tile_m, local_tile_n] =
|
||||
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
|
||||
|
||||
// Extract batch and spatial coordinates from local tile
|
||||
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
|
||||
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
|
||||
const index_t local_n = local_m_start / spatial_per_batch;
|
||||
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
|
||||
|
||||
// Convert to local spatial coordinates
|
||||
const auto local_coords =
|
||||
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
|
||||
|
||||
// Convert to global spatial coordinates
|
||||
const index_t global_n = local_n;
|
||||
const index_t global_d = piece.d_start + local_coords.d;
|
||||
const index_t global_h = piece.h_start + local_coords.h;
|
||||
const index_t global_w = piece.w_start + local_coords.w;
|
||||
|
||||
// Convert to global M index
|
||||
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
|
||||
const index_t global_spatial_flat = FlattenSpatial(
|
||||
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
|
||||
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
|
||||
|
||||
// Set tile indices for GEMM operation
|
||||
i_m = amd_wave_read_first_lane(global_m);
|
||||
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
|
||||
CallExplicitGemm(kargs);
|
||||
}
|
||||
else
|
||||
{
|
||||
// No spatial offsets needed for regular path
|
||||
a_ptr = base_a_ptr;
|
||||
c_ptr = base_c_ptr;
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
// No split-image: use standard tile partitioning
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
|
||||
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
}
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// Use global descriptors for all cases
|
||||
const auto& a_desc = kargs.a_grid_desc_m_k;
|
||||
const auto& b_desc = kargs.b_grid_desc_n_k;
|
||||
const auto& c_desc = kargs.c_grid_desc_m_n;
|
||||
// Split-N handling: Get which split this workgroup handles
|
||||
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
// Calculate batch offset for this split
|
||||
const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split);
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
// Calculate memory offsets for this split
|
||||
const long_index_t input_batch_offset =
|
||||
static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.input_batch_stride);
|
||||
const long_index_t output_batch_offset =
|
||||
static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.output_batch_stride);
|
||||
|
||||
// Calculate base pointers with group and batch offsets
|
||||
const InDataType* base_a_ptr =
|
||||
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
|
||||
group_offset_b; // No batch offset for weights!
|
||||
OutDataType* base_c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// Apply group offsets to D tensors
|
||||
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto d) {
|
||||
using DType = std::tuple_element_t<d, DsDataType>;
|
||||
ds_ptr_with_offsets[d] = static_cast<const DType*>(kargs.ds_ptr[d]) +
|
||||
group_offset_c + output_batch_offset;
|
||||
});
|
||||
|
||||
// =====================================================================
|
||||
// Split-image: Map local block to global tile index (if enabled)
|
||||
// =====================================================================
|
||||
const InDataType* a_ptr;
|
||||
OutDataType* c_ptr;
|
||||
index_t i_m = 0;
|
||||
index_t i_n = 0;
|
||||
|
||||
// Pre-calculate block_id (used in both split-image and non-split paths)
|
||||
const index_t block_id = static_cast<index_t>(blockIdX);
|
||||
|
||||
if constexpr(EnableSplitImage)
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
// Add spatial offsets for split-image (constexpr optimization)
|
||||
a_ptr = base_a_ptr + kargs.spatial_offset_in;
|
||||
c_ptr = base_c_ptr + kargs.spatial_offset_out;
|
||||
|
||||
// Find which piece owns this block using binary search
|
||||
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
|
||||
const index_t piece_id =
|
||||
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
|
||||
const auto& piece = kargs.split_image.pieces[piece_id];
|
||||
const auto& split_info = kargs.split_image;
|
||||
|
||||
// Calculate local block ID and tile indices
|
||||
const index_t local_block_id = block_id - piece.block_start;
|
||||
const index_t local_gemm_m =
|
||||
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
|
||||
const auto [local_tile_m, local_tile_n] =
|
||||
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
|
||||
|
||||
// Extract batch and spatial coordinates from local tile
|
||||
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
|
||||
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
|
||||
const index_t local_n = local_m_start / spatial_per_batch;
|
||||
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
|
||||
|
||||
// Convert to local spatial coordinates
|
||||
const auto local_coords =
|
||||
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
|
||||
|
||||
// Convert to global spatial coordinates
|
||||
const index_t global_n = local_n;
|
||||
const index_t global_d = piece.d_start + local_coords.d;
|
||||
const index_t global_h = piece.h_start + local_coords.h;
|
||||
const index_t global_w = piece.w_start + local_coords.w;
|
||||
|
||||
// Convert to global M index
|
||||
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
|
||||
const index_t global_spatial_flat = FlattenSpatial(
|
||||
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
|
||||
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
|
||||
|
||||
// Set tile indices for GEMM operation
|
||||
i_m = amd_wave_read_first_lane(global_m);
|
||||
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
|
||||
}
|
||||
else
|
||||
{
|
||||
// No spatial offsets needed for regular path
|
||||
a_ptr = base_a_ptr;
|
||||
c_ptr = base_c_ptr;
|
||||
|
||||
// No split-image: use standard tile partitioning
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
|
||||
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
}
|
||||
|
||||
// Use global descriptors for all cases
|
||||
const auto& a_desc = kargs.a_grid_desc_m_k;
|
||||
const auto& b_desc = kargs.b_grid_desc_n_k;
|
||||
const auto& c_desc = kargs.c_grid_desc_m_n;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
@@ -1110,26 +1175,7 @@ struct GroupedConvolutionForwardKernel
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@ template <index_t NDimSpatial_,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
bool EnableSplitImage_ = false>
|
||||
bool EnableSplitImage_ = false,
|
||||
bool ExplicitGemm_ = false>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -89,8 +90,9 @@ struct GroupedConvTraits
|
||||
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
};
|
||||
// Compile time parameters
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr bool ExplicitGemm = ExplicitGemm_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
|
||||
Reference in New Issue
Block a user