add batch_stride into batched gemm (#314)

* add batch_stride

* fixed test

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
zjing14
2022-07-01 01:35:37 -05:00
committed by GitHub
parent 0dcb3496cf
commit 1c8126a4c2
5 changed files with 52 additions and 17 deletions

View File

@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,

View File

@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1,
1,
a_element_op,
@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1,
1,
a_element_op,