mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Multi AB support for wave transfer (#3578)
* Add multi AB support to wave transfer * Improviments to multi ABD examples * Add instances and use intrawave v1 instead of interwave * Apply changes to other transfers * Wave transfer: add support for multiple internal vgpr buffers * Fix compilation error gfx11
This commit is contained in:
@@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8, 32, 1>,
|
||||
S<8, 16, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -174,6 +174,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
|
||||
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
|
||||
|
||||
@@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8, 32, 1>,
|
||||
S<8, 16, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -170,6 +170,28 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
|
||||
|
||||
@@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<4, 64, 1>,
|
||||
S<4, 16, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -233,6 +233,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideD = f_get_default_stride(M, N, StrideD, DLayout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
@@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8, 32, 1>,
|
||||
S<8, 16, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -173,6 +173,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
|
||||
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
|
||||
Reference in New Issue
Block a user