mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
add batch_stride into batched gemm (#314)
* add batch_stride * fixed test Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB,
|
||||
ck::index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
|
||||
@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
compute_ptr_offset_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
|
||||
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
|
||||
block_2_ctile_map_{
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
|
||||
M01_{M01},
|
||||
@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
|
||||
Reference in New Issue
Block a user