mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using async pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
99
example/ck_tile/42_mx_gemm/mx_gemm.hpp
Normal file
99
example/ck_tile/42_mx_gemm/mx_gemm.hpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
|
||||
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
{
|
||||
using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>;
|
||||
|
||||
MXGemmHostArgs(const void* a_ptr,
|
||||
const void* b_ptr,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ScaleM scale_m_,
|
||||
ScaleN scale_n_)
|
||||
: Base({a_ptr},
|
||||
{b_ptr},
|
||||
{},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{},
|
||||
stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
ScaleM scale_m;
|
||||
ScaleN scale_n;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 512;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false; // comp_async uses double buffer
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
Reference in New Issue
Block a user