mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using async pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
@@ -7,6 +7,7 @@
|
||||
#include <variant>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
@@ -241,6 +242,27 @@ struct GemmConfigComputeV6 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeAsync : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_ASYNC;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
{
|
||||
@@ -375,6 +397,15 @@ struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
using CDataType = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_fp4_t;
|
||||
using BDataType = ck_tile::pk_fp4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
@@ -423,6 +454,15 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_ASYNC>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
{
|
||||
|
||||
@@ -30,7 +30,6 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
@@ -369,7 +368,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
sizeof(ADataType) * M * K / ck_tile::numeric_traits<ADataType>::PackedSize +
|
||||
sizeof(BDataType) * N * K / ck_tile::numeric_traits<BDataType>::PackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
|
||||
@@ -262,6 +262,23 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
if(data_type == "fp4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::pk_fp4_t>::Pipeline ==
|
||||
ck_tile::GemmPipeline::COMPUTE_ASYNC &&
|
||||
GemmConfig<ck_tile::pk_fp4_t>::K_Warp_Tile == 128)
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::pk_fp4_t>,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
Reference in New Issue
Block a user