WMMA support for gemm_multiply_multiply_wp (#3278)

* Initial implementation with splitK support

* Add gfx11 support

* Fix compilation error

* Add instances

* Add irregular instances

* Fix GetBuffer arguments

* Minor changes

* Address review comments

* Fix compilation errors

* Fix copyright header
This commit is contained in:
Enrico Degregori
2025-12-03 16:38:23 +01:00
committed by GitHub
parent f29b67cf9b
commit 161835533b
30 changed files with 2482 additions and 86 deletions

View File

@@ -30,7 +30,7 @@ void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int KLane = ck::get_warp_size() / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
@@ -156,8 +156,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -345,6 +345,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;