Multi AB support for wave transfer (#3578)

* Add multi AB support to wave transfer

* Improvements to multi ABD examples

* Add instances and use intrawave v1 instead of interwave

* Apply changes to other transfers

* Wave transfer: add support for multiple internal vgpr buffers

* Fix compilation error gfx11
This commit is contained in:
Enrico Degregori
2026-01-29 19:29:40 +01:00
committed by GitHub
parent fabac7e2c3
commit f16d9100e4
21 changed files with 374 additions and 188 deletions

View File

@@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<8, 32, 1>,
S<8, 16, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
8,
0,
1,
@@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v1>;
int main(int argc, char* argv[])
{
@@ -174,6 +174,29 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));

View File

@@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<8, 32, 1>,
S<8, 16, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
8,
0,
1,
@@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v1>;
int main(int argc, char* argv[])
{
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -170,6 +170,28 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));

View File

@@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<4, 64, 1>,
S<4, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
1,
1,
8,
8,
0,
1,
@@ -233,6 +233,29 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideD = f_get_default_stride(M, N, StrideD, DLayout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

View File

@@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<8, 32, 1>,
S<8, 16, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
8,
0,
1,
@@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v1>;
int main(int argc, char* argv[])
{
@@ -173,6 +173,29 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));