This commit is contained in:
Jakub Piasecki
2025-06-23 14:02:03 +00:00
parent 1ac8f1c744
commit f5345174e4
7 changed files with 89 additions and 56 deletions

View File

@@ -23,7 +23,7 @@ template <ck_tile::index_t NDimSpatial,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;

View File

@@ -105,4 +105,5 @@ auto create_args(int argc, char* argv[])
}
// host API
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);
using GroupedConvHostArgs = ck_tile::GroupedConvHostArgs<const void*, const void*, void*>;
float grouped_conv_fwd(const GroupedConvHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -32,7 +32,7 @@ template <ck_tile::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat)
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, int n_warmup, int n_repeat)
{
float ave_time = grouped_conv_fwd<NDimSpatial,
InDataType,
@@ -143,7 +143,7 @@ int run_grouped_conv_fwd_example_with_layouts(
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.SetZero();
ck_tile::GroupedConvHostArgs args(conv_param,
ck_tile::GroupedConvHostArgs<const void*, const void*, void*> args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},

View File

@@ -435,6 +435,47 @@ struct GroupedConvolutionBackwardWeightKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
/// @brief Per-workgroup split-K bookkeeping for the backward-weight GEMM.
///
/// For a given split-K batch index this computes
///  - the element offset along the GEMM K dimension to add to the A and B base pointers, and
///  - the K extent (`splitted_k`) this batch has to process.
///
/// Every batch except the last one handles `KRead` elements; the last batch handles whatever
/// remains of `kargs.GemmK`.
struct SplitKBatchOffset
{
    /// @param kargs Kernel arguments; `k_batch` and `GemmK` are read.
    /// @param k_id  Split-K batch index handled by this workgroup (defaults to `blockIdx.z`).
    __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
    {
        // K extent of one warp tile; every batch reads a multiple of it.
        constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
        const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
        // ceil(GemmK / (k_batch * K1)) * K1: K elements read by each non-tail batch.
        const index_t KRead =
            __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1);

        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
        }
        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
        {
            // not supported: a_k_split_offset keeps its zero default
        }

        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            // not supported: b_k_split_offset keeps its zero default
        }
        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
        {
            b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
        }

        if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
        {
            splitted_k = __builtin_amdgcn_readfirstlane(KRead);
        }
        else
        {
            // Tail batch: take the remainder of the same K extent KRead was derived from
            // (GemmK), so that the per-batch extents add up to GemmK.
            splitted_k =
                __builtin_amdgcn_readfirstlane(kargs.GemmK - KRead * (kargs.k_batch - 1));
        }
    }

    // Zero defaults keep the offsets well-defined for the layouts that are not supported
    // above; callers add them to the A/B base pointers unconditionally.
    index_t a_k_split_offset = 0;
    index_t b_k_split_offset = 0;
    index_t splitted_k       = 0;
};
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
@@ -566,7 +607,8 @@ struct GroupedConvolutionBackwardWeightKernel
const InDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
WeiDataType* c_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
@@ -700,6 +742,7 @@ struct GroupedConvolutionBackwardWeightKernel
WeiDataType* c_ptr,
void* smem_ptr_0,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
@@ -711,8 +754,8 @@ struct GroupedConvolutionBackwardWeightKernel
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop =
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -751,6 +794,7 @@ struct GroupedConvolutionBackwardWeightKernel
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
@@ -761,8 +805,8 @@ struct GroupedConvolutionBackwardWeightKernel
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop =
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -787,6 +831,8 @@ struct GroupedConvolutionBackwardWeightKernel
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
@@ -794,8 +840,8 @@ struct GroupedConvolutionBackwardWeightKernel
// options
// conv_bwd_weight = Out * In = Weight
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + splitk_batch_offset.a_k_split_offset;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b + splitk_batch_offset.b_k_split_offset;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
// allocate LDS
@@ -809,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
}
}
else
@@ -818,7 +864,7 @@ struct GroupedConvolutionBackwardWeightKernel
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
}

View File

@@ -34,7 +34,7 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -103,7 +103,7 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -179,7 +179,7 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -366,6 +366,7 @@ struct GroupedConvolutionForwardKernel
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType>;
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
// TODO: Enable this
static constexpr bool IsSplitKSupported = false;
@@ -388,7 +389,7 @@ struct GroupedConvolutionForwardKernel
// clang-format on
}
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args)
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdHostArgs& args)
{
const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
args.output_spatial_lengths_.end(),
@@ -401,7 +402,7 @@ struct GroupedConvolutionForwardKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvHostArgs& hostArgs)
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
{
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
}

View File

@@ -14,14 +14,15 @@ namespace ck_tile {
/// This structure is passed to Grouped Convolution Kernels when creating the kernel
/// arguments object. It contains all the information required to
/// build the proper kernel arguments and launch the kernel on the GPU.
template <typename InPtr, typename WeiPtr, typename OutPtr>
struct GroupedConvHostArgs : public conv::ConvParam
{
CK_TILE_HOST GroupedConvHostArgs() = delete;
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
const void* in_ptr_,
const void* wei_ptr_,
InPtr in_ptr_,
WeiPtr wei_ptr_,
const std::vector<const void*> ds_ptr_,
void* out_ptr_,
OutPtr out_ptr_,
index_t k_batch_)
: conv::ConvParam(conv_param),
in_ptr(in_ptr_),
@@ -32,37 +33,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
{
}
const void* in_ptr;
const void* wei_ptr;
InPtr in_ptr;
WeiPtr wei_ptr;
const std::vector<const void*> ds_ptr;
void* out_ptr;
OutPtr out_ptr;
index_t k_batch;
};
/// @brief Host-side argument bundle for the grouped convolution backward-weight path.
///
/// Extends conv::ConvParam with the buffer pointers and the `k_batch` value supplied by the
/// caller. The weight pointer is the only mutable buffer; input and output are read-only.
struct GroupedConvBwdWeightHostArgs : public conv::ConvParam
{
    // No default state: every field has to be provided by the caller.
    CK_TILE_HOST GroupedConvBwdWeightHostArgs() = delete;

    /// @param conv_param Convolution problem description, copied into the base class.
    /// @param in_ptr_    Read-only input buffer.
    /// @param wei_ptr_   Writable weight buffer.
    /// @param ds_ptr_    Read-only auxiliary "D" buffers (copied).
    /// @param out_ptr_   Read-only output buffer.
    /// @param k_batch_   Stored as `k_batch` (presumably the split-K batch count -- confirm
    ///                   against the kernel that consumes it).
    CK_TILE_HOST GroupedConvBwdWeightHostArgs(ConvParam conv_param,
                                              const void* in_ptr_,
                                              void* wei_ptr_,
                                              const std::vector<const void*> ds_ptr_,
                                              const void* out_ptr_,
                                              index_t k_batch_)
        : conv::ConvParam(conv_param),
          in_ptr(in_ptr_),
          wei_ptr(wei_ptr_),
          ds_ptr(ds_ptr_),
          out_ptr(out_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* in_ptr;                    // input buffer (read-only)
    void* wei_ptr;                         // weight buffer (writable)
    const std::vector<const void*> ds_ptr; // auxiliary D buffers (read-only)
    const void* out_ptr;                   // output buffer (read-only)
    index_t k_batch;                       // value passed as k_batch_
};
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
template <index_t NDimSpatial_,
ConvolutionSpecialization ConvSpecialization_,

View File

@@ -415,7 +415,7 @@ struct TransformConvBwdWeightToGemm
#endif
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
CK_TILE_HOST auto make_out_grid_desc() const
CK_TILE_HOST auto make_out_grid_desc(const index_t GemmKBatch) const
{
// NWGK
const index_t NDoHoWoStride = G_ * K_;
@@ -423,7 +423,7 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
return make_naive_tensor_descriptor(make_tuple(N_ * Wo / GemmKBatch_, K_),
make_tuple(NDoHoWoStride, KStride));
}
@@ -538,23 +538,22 @@ struct TransformConvBwdWeightToGemm
// properties
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t GemmKBatch) const
{
// Assume NumGroupsToMerge == 1 for now
const index_t GemmKTotal = N_ * Wo_;
const index_t GemmKTotal = N_ * Wo_ / KBatch; // tmp
const index_t GemmM = K_ * NumGroupsToMerge;
const index_t GemmN = C_ * X_ * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = 1;
const index_t GemmK0 =
integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>(GemmKBatch);
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>(GemmKBatch);
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// A: output tensor comes in K_M
@@ -597,6 +596,13 @@ struct TransformConvBwdWeightToGemm
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -604,7 +610,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tuple(out_gemmkpad_gemmm_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkpad_gemmn_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}