mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
fix settings for example, fix some things in pipeline
This commit is contained in:
@@ -31,7 +31,7 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool UsePersistentKernel = false>
|
||||
bool UsePersistentKernel = true>
|
||||
float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
@@ -83,7 +83,7 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
UsePersistentKernel,
|
||||
GemmConfig::NumWaveGroups,
|
||||
true>;
|
||||
false>;
|
||||
|
||||
using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -152,9 +152,9 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "32", "m dimension")
|
||||
.insert("n", "512", "n dimension")
|
||||
.insert("k", "256", "k dimension")
|
||||
arg_parser.insert("m", "4096", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
@@ -169,7 +169,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
|
||||
@@ -39,7 +39,8 @@ struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_GemmConfig16
|
||||
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
@@ -70,3 +71,17 @@ struct MXfp4_GemmConfig16
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
@@ -49,25 +49,25 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
|
||||
// Scale tensors
|
||||
// Assuming block scale 32
|
||||
ck_tile::index_t scale_n_size = N / 32;
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
ck_tile::index_t scale_k_size = K / 32;
|
||||
ck_tile::HostTensor<ck_tile::e8m0_t> scale_a_host(
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(
|
||||
ck_tile::HostTensorDescriptor({M, scale_k_size}, {scale_k_size, 1}));
|
||||
ck_tile::HostTensor<ck_tile::e8m0_t> scale_b_host(
|
||||
ck_tile::HostTensorDescriptor({scale_k_size, scale_n_size}, {scale_n_size, 1}));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(
|
||||
ck_tile::HostTensorDescriptor({scale_k_size, N}, {1, scale_k_size}));
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
|
||||
ck_tile::FillUniformDistribution<ck_tile::e8m0_t>{-1.f, 1.f}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ck_tile::e8m0_t>{-1.f, 1.f}(scale_b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_b_host);
|
||||
break;
|
||||
case 1:
|
||||
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
|
||||
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
|
||||
ck_tile::FillConstant<ck_tile::e8m0_t>{ck_tile::e8m0_t(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ck_tile::e8m0_t>{ck_tile::e8m0_t(1.f)}(scale_b_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
|
||||
// Scale pointers
|
||||
using ScaleM = ck_tile::MXScalePointer<1, 32>; // per-token
|
||||
using ScaleN = ck_tile::MXScalePointer<32, 32>; // per-block
|
||||
using ScaleM = ck_tile::MXScalePointer<1, 32>; // in blocks of 32 in K
|
||||
using ScaleN = ck_tile::MXScalePointer<1, 32>;
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
@@ -104,14 +104,31 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
|
||||
(void)ave_time;
|
||||
|
||||
bool pass = true;
|
||||
if(validation > 0)
|
||||
{
|
||||
// get output data from device
|
||||
c_dev_buf.FromDevice(c_host.data());
|
||||
// TODO: Implement validation logic (reference GEMM with scales)
|
||||
// For now just print success if it runs
|
||||
std::cout << "Validation not implemented yet." << std::endl;
|
||||
|
||||
// compute reference
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return 0;
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
|
||||
int run_mx_gemm_example(int argc, char* argv[])
|
||||
@@ -126,24 +143,28 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
std::string mx_prec = arg_parser.get_str("mx_prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
MXfp4_GemmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
MXfp4_GemmConfig16,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
MXfp8_GemmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Only fp4xfp4 is supported currently!");
|
||||
throw std::runtime_error("Only fp4 and fp8 is supported currently!");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user