Wmma support for gemm_ab_scale (#3314)

* Support gemm_ab_scale:

 - Add tests
 - Integrate scaling implementation in multiple D
 - Generalize existing b_scale for ab_scale
 - Add instances
 - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK
 - Add support for all layouts supported by xdl
 - Fix splitk xdl

* Fix copyright

* Wmma support for gemm_blockscale_wp (#3315)

* Support for  preshuffle with ab scale

 - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale
 - add support for AScaleLayout amnd BScaleLayout (can be different
   from ALayout and BLayout, respectively)
 - add Run method in v1 pipeline to support preshuffle + scaling
 - add support for preshuffle gemms in common invoker
 - Add splitk support

* Fix copyright header
This commit is contained in:
Enrico Degregori
2025-12-11 09:06:20 +01:00
committed by GitHub
parent d66e5f667c
commit ce99cab605
51 changed files with 5144 additions and 552 deletions

View File

@@ -109,8 +109,8 @@ bool profile_gemm_ab_scale_impl(int do_verification,
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
a1_m_k.GenerateTensorValue(GeneratorTensor_2<A1DataType>{-1, 2});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-1, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
@@ -302,7 +302,7 @@ bool profile_gemm_ab_scale_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
<< gb_per_sec << " GB/s, " << op_name << ", KBatch " << KBatch << std::endl;
if(tflops > best_tflops)
{

View File

@@ -29,7 +29,7 @@ void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int KLane = ck::get_warp_size() / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack