mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Remove the obsolete template parameters.
This commit is contained in:
@@ -21,8 +21,6 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
ck_tile::index_t NumGroupsToMerge = 1,
|
||||
ck_tile::index_t MPerGroup = 0,
|
||||
ck_tile::index_t NPerGroup = 0,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
@@ -66,9 +64,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
NumGroupsToMerge,
|
||||
MPerGroup,
|
||||
NPerGroup>;
|
||||
NumGroupsToMerge>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
@@ -160,8 +156,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
template <
|
||||
typename InPrecType,
|
||||
ck_tile::index_t NumGroupsToMerge = 1,
|
||||
ck_tile::index_t MPerGroup = 0,
|
||||
ck_tile::index_t NPerGroup = 0,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
@@ -185,9 +179,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType,
|
||||
NumGroupsToMerge,
|
||||
MPerGroup,
|
||||
NPerGroup>(
|
||||
NumGroupsToMerge>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
@@ -196,9 +188,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType,
|
||||
NumGroupsToMerge,
|
||||
MPerGroup,
|
||||
NPerGroup>(
|
||||
NumGroupsToMerge>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
@@ -207,9 +197,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType,
|
||||
NumGroupsToMerge,
|
||||
MPerGroup,
|
||||
NPerGroup>(
|
||||
NumGroupsToMerge>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
@@ -223,8 +211,6 @@ int run(const std::string& in_layout,
|
||||
const std::string& wei_layout,
|
||||
const std::string& out_layout,
|
||||
int num_groups_to_merge,
|
||||
int m_per_group,
|
||||
int n_per_group,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
@@ -234,47 +220,7 @@ int run(const std::string& in_layout,
|
||||
}
|
||||
else if (num_groups_to_merge == 8)
|
||||
{
|
||||
if (m_per_group == 1 && n_per_group == 16)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 1, 16>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 2 && n_per_group == 16)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 2, 16>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 4 && n_per_group == 16)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 4, 16>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 1 && n_per_group == 4)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 1, 4>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 2 && n_per_group == 4)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 2, 4>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 4 && n_per_group == 4)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 4, 4>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 1 && n_per_group == 8)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 1, 8>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 2 && n_per_group == 8)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 2, 8>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if (m_per_group == 4 && n_per_group == 8)
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType, 8, 4, 8>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported MPerGroup and NPerGroup combination for NumGroupsToMerge=8! Supported combinations are (1,16), (2,8), (4,4), (8,2), (16,1).");
|
||||
}
|
||||
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<InPrecType>(in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -293,18 +239,16 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
ck_tile::index_t num_groups_to_merge = arg_parser.get_int("num_groups_to_merge");
|
||||
ck_tile::index_t m_per_group = arg_parser.get_int("m_per_group");
|
||||
ck_tile::index_t n_per_group = arg_parser.get_int("n_per_group");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, num_groups_to_merge, m_per_group, n_per_group, argc, argv);
|
||||
in_layout, wei_layout, out_layout, num_groups_to_merge, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, num_groups_to_merge, m_per_group, n_per_group, argc, argv);
|
||||
in_layout, wei_layout, out_layout, num_groups_to_merge, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -122,9 +122,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("num_groups_to_merge", "1", "Number of groups to merge for grouped convolution")
|
||||
.insert("m_per_group", "0", "Number of elements per grouped block in M-dimension")
|
||||
.insert("n_per_group", "0", "Number of elements per grouped block in N-dimension");
|
||||
.insert("num_groups_to_merge", "1", "Number of groups to merge for grouped convolution");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -10,9 +10,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
ck_tile::index_t NumGroupsToMerge = 1,
|
||||
ck_tile::index_t MPerGroup = 0,
|
||||
ck_tile::index_t NPerGroup = 0>
|
||||
ck_tile::index_t NumGroupsToMerge = 1>
|
||||
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
@@ -25,9 +23,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumGroupsToMerge,
|
||||
MPerGroup,
|
||||
NPerGroup>(
|
||||
NumGroupsToMerge>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
@@ -46,8 +42,6 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
ck_tile::index_t NumGroupsToMerge,
|
||||
ck_tile::index_t MPerGroup,
|
||||
ck_tile::index_t NPerGroup,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
@@ -153,9 +147,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumGroupsToMerge,
|
||||
MPerGroup,
|
||||
NPerGroup>(args, n_warmup, n_repeat);
|
||||
NumGroupsToMerge>(args, n_warmup, n_repeat);
|
||||
|
||||
weight_dev_buf.FromDevice(weight.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -219,13 +219,6 @@ struct CShuffleEpilogue
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// We haven't yet tested the ColumnMajor case.
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue with NumGroupsToMerge > 1 only supports the Row Major layout");
|
||||
}
|
||||
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -264,33 +257,17 @@ struct CShuffleEpilogue
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
return kMPerBlock * kNPerBlock * sizeof(ODataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
//template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
|
||||
{
|
||||
if constexpr (NumGroupsToMerge == 1)
|
||||
{
|
||||
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
return merged_op<MPerGroup, NPerGroup>(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
//return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
|
||||
template <index_t MPerGroup, index_t NPerGroup, typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
|
||||
@@ -620,45 +620,9 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ZYX = kargs.ZYX;
|
||||
const index_t MPerGroup = ConvK;
|
||||
const index_t NPerGroup = ZYX * ConvC;
|
||||
|
||||
// TODO: Fix this check
|
||||
if (GroupedConvTraitsType_::MPerGroup != MPerGroup)
|
||||
{
|
||||
CK_TILE_ERROR("MPerGroup must be equal to Conv K!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Fix this check
|
||||
if (GroupedConvTraitsType_::NPerGroup != NPerGroup)
|
||||
{
|
||||
CK_TILE_ERROR("NPerGroup must be equal to Conv C * ZYX!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Remove this check when ConvC > 1 is implemented.
|
||||
if (ConvC > 1)
|
||||
{
|
||||
CK_TILE_ERROR("Only Conv C == 1 is supported!");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (kargs.NumGroupsPerBatch * ConvC * ZYX > TilePartitioner_::NPerBlock)
|
||||
{
|
||||
CK_TILE_ERROR("NumGroupsPerBatch * Conv C * ZYX must be less or equal to NPerBlock!");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (kargs.NumGroupsPerBatch * ConvK > TilePartitioner_::MPerBlock)
|
||||
{
|
||||
CK_TILE_ERROR("NumGroupsToMerge * Conv K must be less or equal to MPerBlock!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// TODO: Should we enforce
|
||||
// - ConvG % NumGroupsToMerge == 0?
|
||||
// - ConvK % NumGroupsToMerge == 0?
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -845,8 +809,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<
|
||||
GroupedConvTraitsType_::MPerGroup,
|
||||
GroupedConvTraitsType_::NPerGroup,
|
||||
decltype(c_block_window),
|
||||
decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
|
||||
@@ -50,9 +50,7 @@ template <index_t NDimSpatial_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
index_t MPerGroup_ = 0,
|
||||
index_t NPerGroup_ = 0>
|
||||
index_t NumGroupsToMerge_ = 1>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -64,8 +62,6 @@ struct GroupedConvTraits
|
||||
|
||||
public:
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t MPerGroup = MPerGroup_;
|
||||
static constexpr index_t NPerGroup = NPerGroup_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
|
||||
Reference in New Issue
Block a user