update codes

This commit is contained in:
mtgu0705
2025-08-30 03:19:07 -05:00
parent 9c37e55d13
commit 16993acd1d
9 changed files with 2095 additions and 88 deletions

View File

@@ -11,7 +11,7 @@
#include <type_traits>
#include "ck_tile/host.hpp"
#include "mx_prec_flatmm.hpp"
#include "mx_flatmm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -99,17 +99,17 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenMXFlatmmPipeline =
ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -137,7 +137,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
BlockedXDLN_PerWarp>>;
using Kernel =
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
ck_tile::MXFlatmmKernel<TilePartitioner, CodegenMXFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -4,37 +4,12 @@
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
#include "mxfp4_flatmm.hpp"

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};

View File

@@ -239,58 +239,25 @@ int run_mx_flatmm_with_layouts(int argc,
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes());
b_origin_dev_buf.ToDevice(b_origin_host.data());
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes());
c_m_n_host_ref.SetZero();
ck_tile::HostTensor<AccDataType> scale_A(
ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleDataType, AccDataType, CDataType>(
a_host.data(),
b_origin_host.data(),
c_m_n_host_ref.data(),
scale_a.data(),
scale_b.data());
// scaleA = 1 has no effect on the result
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
scale_A_dev_buf.ToDevice(scale_A.data());
// convert scale_b from e8m0 to float
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
c_gpu_ref_dev_buf.SetZero();
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
static_cast<ADataType*>(a_dev_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_gpu_ref_dev_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
stride_C,
M,
DequantGranularityN,
DequantGranularityK,
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;