mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
tmp
This commit is contained in:
@@ -23,7 +23,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
|
||||
@@ -105,4 +105,5 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);
|
||||
using GroupedConvHostArgs = ck_tile::GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
float grouped_conv_fwd(const GroupedConvHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -32,7 +32,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat)
|
||||
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, int n_warmup, int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
@@ -143,7 +143,7 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
|
||||
ck_tile::GroupedConvHostArgs args(conv_param,
|
||||
ck_tile::GroupedConvHostArgs<const void*, const void*, void*> args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
|
||||
@@ -435,6 +435,47 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1);
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
// not supported
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
// not supported
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
@@ -566,7 +607,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
@@ -700,6 +742,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
WeiDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
@@ -711,8 +754,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -751,6 +794,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
@@ -761,8 +805,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -787,6 +831,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
|
||||
const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
|
||||
@@ -794,8 +840,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
// options
|
||||
// conv_bwd_weight = Out * In = Weight
|
||||
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + splitk_batch_offset.a_k_split_offset;
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b + splitk_batch_offset.b_k_split_offset;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
// allocate LDS
|
||||
@@ -809,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -818,7 +864,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -103,7 +103,7 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -179,7 +179,7 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -366,6 +366,7 @@ struct GroupedConvolutionForwardKernel
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType>;
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = false;
|
||||
@@ -388,7 +389,7 @@ struct GroupedConvolutionForwardKernel
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
@@ -401,7 +402,7 @@ struct GroupedConvolutionForwardKernel
|
||||
/// Launch block dimensions: a dim3 constructed from KernelBlockSize.
CK_TILE_HOST static constexpr auto BlockSize()
{
    return dim3(KernelBlockSize);
}
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvHostArgs& hostArgs)
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
|
||||
@@ -14,14 +14,15 @@ namespace ck_tile {
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contains all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr>
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
|
||||
const void* in_ptr_,
|
||||
const void* wei_ptr_,
|
||||
InPtr in_ptr_,
|
||||
WeiPtr wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
void* out_ptr_,
|
||||
OutPtr out_ptr_,
|
||||
index_t k_batch_)
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
@@ -32,37 +33,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
}
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
InPtr in_ptr;
|
||||
WeiPtr wei_ptr;
|
||||
const std::vector<const void*> ds_ptr;
|
||||
void* out_ptr;
|
||||
OutPtr out_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct GroupedConvBwdWeightHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvBwdWeightHostArgs() = delete;
|
||||
CK_TILE_HOST GroupedConvBwdWeightHostArgs(ConvParam conv_param,
|
||||
const void* in_ptr_,
|
||||
void* wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
const void* out_ptr_,
|
||||
index_t k_batch_)
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* in_ptr;
|
||||
void* wei_ptr;
|
||||
const std::vector<const void*> ds_ptr;
|
||||
const void* out_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
|
||||
@@ -415,7 +415,7 @@ struct TransformConvBwdWeightToGemm
|
||||
#endif
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
CK_TILE_HOST auto make_out_grid_desc(const index_t GemmKBatch) const
|
||||
{
|
||||
// NWGK
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
@@ -423,7 +423,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo / GemmKBatch_, K_),
|
||||
make_tuple(NDoHoWoStride, KStride));
|
||||
}
|
||||
|
||||
@@ -538,23 +538,22 @@ struct TransformConvBwdWeightToGemm
|
||||
// properties
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
|
||||
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t GemmKBatch) const
|
||||
{
|
||||
// Assume NumGroupsToMerge == 1 for now
|
||||
const index_t GemmKTotal = N_ * Wo_;
|
||||
const index_t GemmKTotal = N_ * Wo_ / KBatch; // tmp
|
||||
const index_t GemmM = K_ * NumGroupsToMerge;
|
||||
const index_t GemmN = C_ * X_ * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = 1;
|
||||
const index_t GemmK0 =
|
||||
integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
|
||||
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
|
||||
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>(GemmKBatch);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>(GemmKBatch);
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// A: output tensor comes in K_M
|
||||
@@ -597,6 +596,13 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -604,7 +610,7 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkpad_gemmm_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user