mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Consuming binary ops to do A+B / A-B
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "gridwise_binary_elementwise_1d.hpp"
|
||||
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -66,6 +68,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto ScalarPerVector = Number<4>{};
|
||||
|
||||
template <typename Desc_M0>
|
||||
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t threadPerBlock)
|
||||
{
|
||||
const auto m0 = desc_m0.GetLength(I0);
|
||||
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
|
||||
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||
const auto desc_m0_pad =
|
||||
transform_tensor_descriptor(desc_m0,
|
||||
make_tuple(make_right_pad_transform(m0, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m0_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M0(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<2>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<2>{});
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
const auto desc_m0 = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
|
||||
}
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -333,6 +370,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
@@ -426,6 +464,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m0_ =
|
||||
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m0_ =
|
||||
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
@@ -440,6 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
GridDesc_M0 c_grid_desc_m0_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
@@ -468,6 +520,35 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
using Substract = ck::tensor_operation::binary_element_wise::Substract;
|
||||
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
GridDesc_M0,
|
||||
Add,
|
||||
ScalarPerVector>;
|
||||
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
GridDesc_M0,
|
||||
Substract,
|
||||
ScalarPerVector>;
|
||||
const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
GridDesc_M0,
|
||||
Add>;
|
||||
const auto substract_kernel = kernel_elementwise_1d<GridwiseBinSubstract,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
GridDesc_M0,
|
||||
Substract>;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
|
||||
@@ -517,7 +598,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = aux - aux_2 needed here!!!
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
substract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
Substract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
@@ -553,7 +646,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = aux + aux_2 needed here!!!
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
add_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
Add{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -604,7 +709,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// // c_real = aux - aux_2 needed here!!!
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
substract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
Substract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
@@ -640,7 +757,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = aux + aux_2 needed here!!!
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
add_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
Add{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -12,6 +12,39 @@ struct Add
|
||||
{
|
||||
dst = src1 + src2;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
}
|
||||
};
|
||||
|
||||
struct Substract
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(float& dst, const float& src1, const float& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace binary_element_wise
|
||||
|
||||
Reference in New Issue
Block a user