resolve conflicts

This commit is contained in:
Jakub Piasecki
2025-06-16 15:33:24 +00:00
8 changed files with 106 additions and 21 deletions

View File

@@ -19,7 +19,10 @@ template <ck_tile::index_t NDimSpatial,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
@@ -49,7 +52,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, OutLayout>;
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
@@ -68,9 +71,12 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -146,6 +146,7 @@ int run_grouped_conv_fwd_example_with_layouts(
ck_tile::GroupedConvHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);

View File

@@ -24,6 +24,7 @@ struct GroupedConvFwdKernelArgs
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<GroupedConvTraitsType::NDimSpatial,
GroupedConvTraitsType::ConvSpecialization>;
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
template <
typename InLay = typename GroupedConvTraitsType::InLayout,
@@ -61,6 +62,10 @@ struct GroupedConvFwdKernelArgs
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
{
ds_ptr[d] = args.ds_ptr[d];
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
@@ -133,6 +138,10 @@ struct GroupedConvFwdKernelArgs
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
{
ds_ptr[d] = args.ds_ptr[d];
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
@@ -214,6 +223,10 @@ struct GroupedConvFwdKernelArgs
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
{
ds_ptr[d] = args.ds_ptr[d];
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
@@ -270,6 +283,7 @@ struct GroupedConvFwdKernelArgs
const void* in_ptr;
const void* wei_ptr;
std::array<const void*, NumDTensor> ds_ptr;
void* out_ptr;
AGridDescMK a_grid_desc_m_k;
@@ -338,11 +352,17 @@ struct GroupedConvolutionForwardKernel
using InLayout = remove_cvref_t<typename GroupedConvTraitsType::InLayout>;
using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType::WeiLayout>;
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType::OutLayout>;
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType::DsLayout>;
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
@@ -517,10 +537,12 @@ struct GroupedConvolutionForwardKernel
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const InDataType* a_ptr,
const WeiDataType* b_ptr,
OutDataType* c_ptr,
const GroupedConvFwdKernelArgsSpecialized& kargs)
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
const GroupedConvFwdKernelArgsSpecialized& kargs)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
@@ -537,7 +559,21 @@ struct GroupedConvolutionForwardKernel
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
}();
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
"Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported!");
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
},
number<NumDTensor>{});
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
}
template <typename TensorView>
@@ -559,24 +595,35 @@ struct GroupedConvolutionForwardKernel
sequence<true, true>{});
}();
const auto& ds_tensor_view = views.at(I2);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
},
number<NumDTensor>{});
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
const auto& c_tensor_view = views.at(I3);
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
return make_tuple(a_pad_view, b_pad_view, c_pad_view);
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& c_pad_view = views.at(I2);
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
return make_tile_window(a_pad_view,
@@ -592,12 +639,21 @@ struct GroupedConvolutionForwardKernel
{i_n, 0});
}();
const auto ds_block_window = generate_tuple(
[&](auto i) {
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
},
number<NumDTensor>{});
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, b_block_window, c_block_window);
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
}
/**
@@ -614,6 +670,7 @@ struct GroupedConvolutionForwardKernel
*/
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* smem_ptr_0,
const GroupedConvFwdKernelArgsSpecialized& kargs,
@@ -622,7 +679,8 @@ struct GroupedConvolutionForwardKernel
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(a_ptr, b_ptr, c_ptr, kargs);
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -633,15 +691,16 @@ struct GroupedConvolutionForwardKernel
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr_0);
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
@@ -661,6 +720,7 @@ struct GroupedConvolutionForwardKernel
*/
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
@@ -670,7 +730,8 @@ struct GroupedConvolutionForwardKernel
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(a_ptr, b_ptr, c_ptr, kargs);
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -680,15 +741,16 @@ struct GroupedConvolutionForwardKernel
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr_0);
c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1);
}
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
@@ -719,7 +781,8 @@ struct GroupedConvolutionForwardKernel
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
RunGemm2LDS(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
}
}
else
@@ -728,7 +791,7 @@ struct GroupedConvolutionForwardKernel
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
}
}
}

View File

@@ -20,11 +20,13 @@ struct GroupedConvHostArgs : public conv::ConvParam
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
const void* in_ptr_,
const void* wei_ptr_,
const std::vector<const void*> ds_ptr_,
void* out_ptr_,
index_t k_batch_)
: conv::ConvParam(conv_param),
in_ptr(in_ptr_),
wei_ptr(wei_ptr_),
ds_ptr(ds_ptr_),
out_ptr(out_ptr_),
k_batch(k_batch_)
{
@@ -32,6 +34,7 @@ struct GroupedConvHostArgs : public conv::ConvParam
const void* in_ptr;
const void* wei_ptr;
const std::vector<const void*> ds_ptr;
void* out_ptr;
index_t k_batch;
};
@@ -62,13 +65,23 @@ template <index_t NDimSpatial_,
ConvolutionSpecialization ConvSpecialization_,
typename InLayout_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_>
struct GroupedConvTraits
{
private:
// Helper whose return type describes the implicit-GEMM layout of the D tensors.
// It asks generate_tuple for DsLayout_::size() elements, and the generator
// ignores the index and always returns a gemm::RowMajor tag, so every D tensor
// is treated as row-major regardless of its convolution layout.
// NOTE(review): presumably generate_tuple yields a tuple with one element per
// index - confirm against its definition.
static constexpr auto generate_implicit_gemm_layout()
{
return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
number<DsLayout_::size()>{});
}
public:
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
true,
@@ -76,6 +89,8 @@ struct GroupedConvTraits
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
static constexpr index_t NumDTensor = DsLayout::size();
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};
} // namespace ck_tile