fix settings for example, fix some things in pipeline

Sami Remes
2025-12-19 12:35:03 -05:00
parent 6a4951cf8c
commit 86cc59e754
9 changed files with 105 additions and 115 deletions

View File

@@ -31,7 +31,7 @@ template <typename GemmConfig,
typename CLayout,
typename ScaleM,
typename ScaleN,
-bool UsePersistentKernel = false>
+bool UsePersistentKernel = true>
float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
@@ -83,7 +83,7 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
GemmConfig::UseStructuredSparsity,
UsePersistentKernel,
GemmConfig::NumWaveGroups,
-true>;
+false>;
using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
BDataType,
@@ -152,9 +152,9 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "32", "m dimension")
.insert("n", "512", "n dimension")
.insert("k", "256", "k dimension")
arg_parser.insert("m", "4096", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "4096", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
@@ -169,7 +169,6 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:constant(1)")
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");

View File

@@ -39,7 +39,8 @@ struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
};
// GEMM config with 16x16 warp tile
-struct MXfp4_GemmConfig16
+struct MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -70,3 +71,17 @@ struct MXfp4_GemmConfig16
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
+struct MXfp4_GemmConfig16 : MxGemmConfig
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 256;
+};
+// GEMM config with 16x16 warp tile
+struct MXfp8_GemmConfig16 : MxGemmConfig
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 256;
+};
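The refactor pulls the shared tile parameters into a `MxGemmConfig` base so the fp4 and fp8 configs only restate the tile sizes they differ in. Since these are `static constexpr` members, a derived struct shadows the base's value rather than overriding it, which is enough for templates that read `GemmConfig::N_Tile`. A compilable sketch of the idea (`index_t` is a stand-in; members trimmed):

#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

struct MxGemmConfig
{
    static constexpr index_t M_Tile = 128;
    static constexpr index_t N_Tile = 256;
    static constexpr index_t K_Tile = 256;
    // ... warp shapes, repeats, etc. as in the example header
};

struct MXfp8_GemmConfig16 : MxGemmConfig
{
    static constexpr index_t N_Tile = 128; // shadows the base's 256
};

static_assert(MXfp8_GemmConfig16::N_Tile == 128);
static_assert(MXfp8_GemmConfig16::M_Tile == 128); // inherited unchanged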

View File

@@ -49,25 +49,25 @@ int run_mx_gemm_with_layouts(int argc,
// Scale tensors
// Assuming block scale 32
-ck_tile::index_t scale_n_size = N / 32;
+using ScaleType = ck_tile::e8m0_t;
ck_tile::index_t scale_k_size = K / 32;
ck_tile::HostTensor<ck_tile::e8m0_t> scale_a_host(
ck_tile::HostTensor<ScaleType> scale_a_host(
ck_tile::HostTensorDescriptor({M, scale_k_size}, {scale_k_size, 1}));
ck_tile::HostTensor<ck_tile::e8m0_t> scale_b_host(
ck_tile::HostTensorDescriptor({scale_k_size, scale_n_size}, {scale_n_size, 1}));
ck_tile::HostTensor<ScaleType> scale_b_host(
ck_tile::HostTensorDescriptor({scale_k_size, N}, {1, scale_k_size}));
switch(init_method)
{
case 0:
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
ck_tile::FillUniformDistribution<ck_tile::e8m0_t>{-1.f, 1.f}(scale_a_host);
ck_tile::FillUniformDistribution<ck_tile::e8m0_t>{-1.f, 1.f}(scale_b_host);
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_b_host);
break;
case 1:
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
ck_tile::FillConstant<ck_tile::e8m0_t>{ck_tile::e8m0_t(1.f)}(scale_a_host);
ck_tile::FillConstant<ck_tile::e8m0_t>{ck_tile::e8m0_t(1.f)}(scale_b_host);
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
break;
}
@@ -83,8 +83,8 @@ int run_mx_gemm_with_layouts(int argc,
scale_b_dev_buf.ToDevice(scale_b_host.data());
// Scale pointers
-using ScaleM = ck_tile::MXScalePointer<1, 32>;  // per-token
-using ScaleN = ck_tile::MXScalePointer<32, 32>; // per-block
+using ScaleM = ck_tile::MXScalePointer<1, 32>; // in blocks of 32 in K
+using ScaleN = ck_tile::MXScalePointer<1, 32>;
ScaleM scale_m(reinterpret_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()));
ScaleN scale_n(reinterpret_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()));
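The scale-tensor changes make B's scales K-major: with a block-scale size of 32, A carries an M x (K/32) row-major tensor of e8m0 scales, and B now carries one scale per output column per 32-element K block, laid out (K/32) x N with strides {1, K/32} to match the new descriptor above, instead of sharing scales across 32-column N blocks. A plain-C++ sketch of the indexing this implies (not the ck_tile API); decoding e8m0 as a power of two with bias 127 follows the OCP MX convention:

#include <cmath>
#include <cstdint>
#include <vector>

// e8m0 stores only an 8-bit exponent; decode as 2^(e - 127).
float decode_e8m0(std::uint8_t e) { return std::ldexp(1.0f, int(e) - 127); }

// Dot product of one dequantized A row with one dequantized B column:
// every 32 consecutive K elements share one scale on each side.
float dequant_dot(const std::vector<float>& a_q,       // quantized A row, length K
                  const std::vector<float>& b_q,       // quantized B column, length K
                  const std::vector<std::uint8_t>& sa, // K/32 scales for the row
                  const std::vector<std::uint8_t>& sb, // K/32 scales for the column
                  int K)
{
    float acc = 0.f;
    for(int k = 0; k < K; ++k)
        acc += a_q[k] * decode_e8m0(sa[k / 32]) * b_q[k] * decode_e8m0(sb[k / 32]);
    return acc;
}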
@@ -104,14 +104,31 @@ int run_mx_gemm_with_layouts(int argc,
(void)ave_time;
bool pass = true;
if(validation > 0)
{
// get output data from device
c_dev_buf.FromDevice(c_host.data());
-// TODO: Implement validation logic (reference GEMM with scales)
-// For now just print success if it runs
-std::cout << "Validation not implemented yet." << std::endl;
+// compute reference
+ck_tile::HostTensor<CDataType> c_m_n_host_ref(
+    ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+c_m_n_host_ref.SetZero();
+ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
+    a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
+const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
+const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
+pass = ck_tile::check_err(
+    c_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
+std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
+          << std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
-return 0;
+return pass ? 0 : -1;
}
int run_mx_gemm_example(int argc, char* argv[])
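The new validation path computes a host reference with `reference_mx_gemm` and compares it against the GPU output under mixed relative/absolute tolerances. A sketch of the comparison the thresholds imply (a stand-in for `ck_tile::check_err`, not its actual implementation):

#include <cmath>
#include <cstdio>
#include <vector>

// Element-wise |out - ref| <= atol + rtol * |ref|, as in a typical
// mixed-tolerance check; the diff uses 1e-3 for half inputs, 1e-2 otherwise.
bool check_close(const std::vector<float>& out,
                 const std::vector<float>& ref,
                 float rtol, float atol)
{
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        if(std::fabs(out[i] - ref[i]) > atol + rtol * std::fabs(ref[i]))
        {
            std::printf("mismatch at %zu: %f vs %f\n", i, out[i], ref[i]);
            return false;
        }
    }
    return true;
}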
@@ -126,24 +143,28 @@ int run_mx_gemm_example(int argc, char* argv[])
std::string mx_prec = arg_parser.get_str("mx_prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
-int persistent_opt = arg_parser.get_int("persistent");
if(a_layout == "R" && b_layout == "C")
{
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
{
-if(persistent_opt == 0)
-    return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
-                                    ck_tile::pk_fp4_t,
-                                    float,
-                                    MXfp4_GemmConfig16,
-                                    false>(argc, argv, Row{}, Col{}, Row{});
-else
-    throw std::runtime_error("Only non-persistent kernels are supported currently!");
+return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
+                                ck_tile::pk_fp4_t,
+                                float,
+                                MXfp4_GemmConfig16,
+                                true>(argc, argv, Row{}, Col{}, Row{});
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
float,
MXfp8_GemmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Only fp4xfp4 is supported currently!");
throw std::runtime_error("Only fp4 and fp8 is supported currently!");
}
}
else