mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
added ab_elementwise_op support into splitK Gemm (#956)
* add ab_elementwise * fixed ci * fixed a merge issue * fixed pr comments * fixed a conflict * remove 61_example --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
PipelineVer,
|
||||
ComputeType>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t MPadded_,
|
||||
index_t NPadded_,
|
||||
index_t KPadded_,
|
||||
index_t K0_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
MPadded_,
|
||||
NPadded_,
|
||||
KPadded_,
|
||||
K0_,
|
||||
k_batch_),
|
||||
a_element_op(a_element_op_),
|
||||
b_element_op(b_element_op_),
|
||||
c_element_op(c_element_op_)
|
||||
{
|
||||
}
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
};
|
||||
|
||||
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
|
||||
|
||||
// Invoker
|
||||
@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
karg.M * karg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<typename GridwiseGemm::Argument>(karg),
|
||||
b2c_map,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DefaultBlock2CTileMap>;
|
||||
DefaultBlock2CTileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DefaultBlock2CTileMap>;
|
||||
DefaultBlock2CTileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DefaultBlock2CTileMap>;
|
||||
DefaultBlock2CTileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DefaultBlock2CTileMap>;
|
||||
DefaultBlock2CTileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
return Argument(p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
KBatch};
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
KBatch);
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
Reference in New Issue
Block a user