mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
add batch_stride into batched gemm (#314)
* add batch_stride
* fixed test
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 1c8126a4c2]
This commit is contained in:
@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB,
|
||||
ck::index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
|
||||
@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
compute_ptr_offset_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
|
||||
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
|
||||
block_2_ctile_map_{
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
|
||||
M01_{M01},
|
||||
@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
|
||||
@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int BatchStrideA,
|
||||
int BatchStrideB,
|
||||
int BatchStrideC,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC,
|
||||
@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({row * stride, stride, 1}));
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({col * stride, 1, stride}));
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
|
||||
Tensor<BDataType> b_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
|
||||
Tensor<CDataType> c_g_m_n_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
|
||||
Tensor<CDataType> c_g_m_n_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
|
||||
@@ -150,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
|
||||
@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
|
||||
const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
|
||||
const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
|
||||
|
||||
const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
|
||||
const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
|
||||
const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
|
||||
|
||||
bool pass = ck::profiler::
|
||||
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
do_verification,
|
||||
@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? DefaultStrideA : StrideA,
|
||||
(StrideB < 0) ? DefaultStrideB : StrideB,
|
||||
(StrideC < 0) ? DefaultStrideC : StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
BatchCount);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
|
||||
@@ -25,19 +25,19 @@ int main()
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
|
||||
true, 1, false, 1, M, N, K, K, N, N, BatchCount);
|
||||
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
|
||||
true, 1, false, 1, M, N, K, K, K, N, BatchCount);
|
||||
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
|
||||
true, 1, false, 1, M, N, K, M, N, N, BatchCount);
|
||||
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
|
||||
true, 1, false, 1, M, N, K, M, K, N, BatchCount);
|
||||
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
|
||||
return pass ? 0 : 1;
|
||||
|
||||
Reference in New Issue
Block a user