Remove the obsolete MPerGroup and NPerGroup template parameters.

This commit is contained in:
Ville Pietilä
2025-10-03 14:36:48 +00:00
parent 99fe3df99a
commit 48d22d2b9b
6 changed files with 18 additions and 149 deletions

View File

@@ -21,8 +21,6 @@ template <ck_tile::index_t NDimSpatial,
typename WeiLayout,
typename OutLayout,
ck_tile::index_t NumGroupsToMerge = 1,
ck_tile::index_t MPerGroup = 0,
ck_tile::index_t NPerGroup = 0,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
@@ -66,9 +64,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
WeiLayout,
DsLayout,
OutLayout,
NumGroupsToMerge,
MPerGroup,
NPerGroup>;
NumGroupsToMerge>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
@@ -160,8 +156,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
template <
typename InPrecType,
ck_tile::index_t NumGroupsToMerge = 1,
ck_tile::index_t MPerGroup = 0,
ck_tile::index_t NPerGroup = 0,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_weight_example_prec_type(
@@ -185,9 +179,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
InPrecType,
WeiPrecType,
OutPrecType,
NumGroupsToMerge,
MPerGroup,
NPerGroup>(
NumGroupsToMerge>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
@@ -196,9 +188,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
InPrecType,
WeiPrecType,
OutPrecType,
NumGroupsToMerge,
MPerGroup,
NPerGroup>(
NumGroupsToMerge>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
@@ -207,9 +197,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
InPrecType,
WeiPrecType,
OutPrecType,
NumGroupsToMerge,
MPerGroup,
NPerGroup>(
NumGroupsToMerge>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
@@ -223,8 +211,6 @@ int run(const std::string& in_layout,
const std::string& wei_layout,
const std::string& out_layout,
int num_groups_to_merge,
int m_per_group,
int n_per_group,
int argc,
char* argv[])
{
@@ -234,47 +220,7 @@ int run(const std::string& in_layout,
}
else if (num_groups_to_merge == 8)
{
if (m_per_group == 1 && n_per_group == 16)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 1, 16>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 2 && n_per_group == 16)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 2, 16>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 4 && n_per_group == 16)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 4, 16>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 1 && n_per_group == 4)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 1, 4>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 2 && n_per_group == 4)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 2, 4>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 4 && n_per_group == 4)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 4, 4>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 1 && n_per_group == 8)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 1, 8>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 2 && n_per_group == 8)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 2, 8>(in_layout, wei_layout, out_layout, argc, argv);
}
else if (m_per_group == 4 && n_per_group == 8)
{
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 4, 8>(in_layout, wei_layout, out_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported MPerGroup and NPerGroup combination for NumGroupsToMerge=8! Supported combinations are (1,16), (2,8), (4,4), (8,2), (16,1).");
}
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType>(in_layout, wei_layout, out_layout, argc, argv);
}
else
{
@@ -293,18 +239,16 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
ck_tile::index_t num_groups_to_merge = arg_parser.get_int("num_groups_to_merge");
ck_tile::index_t m_per_group = arg_parser.get_int("m_per_group");
ck_tile::index_t n_per_group = arg_parser.get_int("n_per_group");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(
in_layout, wei_layout, out_layout, num_groups_to_merge, m_per_group, n_per_group, argc, argv);
in_layout, wei_layout, out_layout, num_groups_to_merge, argc, argv);
}
else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, num_groups_to_merge, m_per_group, n_per_group, argc, argv);
in_layout, wei_layout, out_layout, num_groups_to_merge, argc, argv);
}
else
{

View File

@@ -122,9 +122,7 @@ auto create_args(int argc, char* argv[])
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("num_groups_to_merge", "1", "Number of groups to merge for grouped convolution")
.insert("m_per_group", "0", "Number of elements per grouped block in M-dimension")
.insert("n_per_group", "0", "Number of elements per grouped block in N-dimension");
.insert("num_groups_to_merge", "1", "Number of groups to merge for grouped convolution");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -10,9 +10,7 @@ template <ck_tile::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
ck_tile::index_t NumGroupsToMerge = 1,
ck_tile::index_t MPerGroup = 0,
ck_tile::index_t NPerGroup = 0>
ck_tile::index_t NumGroupsToMerge = 1>
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
int n_warmup,
int n_repeat)
@@ -25,9 +23,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
InLayout,
WeiLayout,
OutLayout,
NumGroupsToMerge,
MPerGroup,
NPerGroup>(
NumGroupsToMerge>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
@@ -46,8 +42,6 @@ template <ck_tile::index_t NDimSpatial,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
ck_tile::index_t NumGroupsToMerge,
ck_tile::index_t MPerGroup,
ck_tile::index_t NPerGroup,
typename InLayout,
typename WeiLayout,
typename OutLayout>
@@ -153,9 +147,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
InLayout,
WeiLayout,
OutLayout,
NumGroupsToMerge,
MPerGroup,
NPerGroup>(args, n_warmup, n_repeat);
NumGroupsToMerge>(args, n_warmup, n_repeat);
weight_dev_buf.FromDevice(weight.data());
bool pass = true;

View File

@@ -219,13 +219,6 @@ struct CShuffleEpilogue
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
if constexpr(NumGroupsToMerge > 1)
{
// We haven't yet tested the ColumnMajor case.
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue with NumGroupsToMerge > 1 only supports the Row Major layout");
}
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
@@ -264,33 +257,17 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(NumGroupsToMerge > 1)
{
return kMPerBlock * kNPerBlock * sizeof(ODataType);
}
else
{
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
//template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
if constexpr (NumGroupsToMerge == 1)
{
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
}
else
{
return merged_op<MPerGroup, NPerGroup>(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
//return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
}
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
}
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>

View File

@@ -620,45 +620,9 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
if (GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ZYX = kargs.ZYX;
const index_t MPerGroup = ConvK;
const index_t NPerGroup = ZYX * ConvC;
// TODO: Fix this check
if (GroupedConvTraitsType_::MPerGroup != MPerGroup)
{
CK_TILE_ERROR("MPerGroup must be equal to Conv K!");
return false;
}
// TODO: Fix this check
if (GroupedConvTraitsType_::NPerGroup != NPerGroup)
{
CK_TILE_ERROR("NPerGroup must be equal to Conv C * ZYX!");
return false;
}
// TODO: Remove this check when ConvC > 1 is implemented.
if (ConvC > 1)
{
CK_TILE_ERROR("Only Conv C == 1 is supported!");
return false;
}
if (kargs.NumGroupsPerBatch * ConvC * ZYX > TilePartitioner_::NPerBlock)
{
CK_TILE_ERROR("NumGroupsPerBatch * Conv C * ZYX must be less or equal to NPerBlock!");
return false;
}
if (kargs.NumGroupsPerBatch * ConvK > TilePartitioner_::MPerBlock)
{
CK_TILE_ERROR("NumGroupsToMerge * Conv K must be less or equal to MPerBlock!");
return false;
}
}
// TODO: Should we enforce
// - ConvG % NumGroupsToMerge == 0?
// - ConvK % NumGroupsToMerge == 0?
return true;
}
@@ -845,8 +809,6 @@ struct GroupedConvolutionBackwardWeightKernel
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<
GroupedConvTraitsType_::MPerGroup,
GroupedConvTraitsType_::NPerGroup,
decltype(c_block_window),
decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);

View File

@@ -50,9 +50,7 @@ template <index_t NDimSpatial_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t NumGroupsToMerge_ = 1,
index_t MPerGroup_ = 0,
index_t NPerGroup_ = 0>
index_t NumGroupsToMerge_ = 1>
struct GroupedConvTraits
{
private:
@@ -64,8 +62,6 @@ struct GroupedConvTraits
public:
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t MPerGroup = MPerGroup_;
static constexpr index_t NPerGroup = NPerGroup_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;