[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)

[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-10 20:12:43 +00:00
committed by assistant-librarian[bot]
parent b8def2c724
commit 8f27f65d44
40 changed files with 2729 additions and 43 deletions

View File

@@ -7,6 +7,7 @@
#include <variant>
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
@@ -241,6 +242,27 @@ struct GemmConfigComputeV6 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
};
template <typename PrecType>
struct GemmConfigComputeAsync : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_ASYNC;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool UseStructuredSparsity = false;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
@@ -375,6 +397,15 @@ struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
using CDataType = int32_t;
};
template <>
struct GemmTypeConfig<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t>
{
using ADataType = ck_tile::pk_fp4_t;
using BDataType = ck_tile::pk_fp4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
@@ -423,6 +454,15 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_ASYNC>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
{