This commit is contained in:
Ding, Yi
2025-05-07 09:40:47 +00:00
parent d12b750bb0
commit 8d51a4ae96
7 changed files with 72 additions and 42 deletions

View File

@@ -200,10 +200,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
switch(config.init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(ck::float2_t(1.0f))}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(ck::float2_t(1.0f))}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
@@ -347,16 +347,16 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
std::cout << "Comparing results..." << std::endl;
}
if(config.init_method == 0)
{
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
// if(config.init_method == 0)
// {
// auto expected = static_cast<float>(K);
// auto computed = type_convert<float>(c_m_n_device_result(1, 12));
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
<< std::endl;
}
// res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
// std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
// << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
// << std::endl;
// }
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,

View File

@@ -45,24 +45,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
64, // BlockSize: Thread block size
16, // MPerBlock
16, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
16, // MPerXDL
16, // NPerXDL
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
@@ -71,8 +71,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA

View File

@@ -344,13 +344,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
// a_thread_copy_.Run(
// a_block_desc_m0_m1_m2_k,
// make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
// a_block_buf,
// a_thread_desc_,
// make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
// a_thread_buf);
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -406,13 +406,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
CK_TILE_PRINT<xdlops_gemm.K1PerXdlops>();
CK_TILE_PRINT<decltype(xdlops_gemm)>();
// CK_TILE_PRINT<xdlops_gemm.K1PerXdlops>();
// CK_TILE_PRINT<decltype(xdlops_gemm)>();
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / 2>::type;
// mfma input type = pk_f4_t, 32
CK_TILE_PRINT<mfma_input_type_a>();
// CK_TILE_PRINT<mfma_input_type_a>();
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / 2>::type;
@@ -538,6 +538,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
});
});
});
printf("a_thread_buf: %x %x\n",
*reinterpret_cast<const uint8_t*>(&a_scale_thread_buf[I0]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_buf[I0]));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {

View File

@@ -167,15 +167,16 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
//
// Should be a multiple of k_per_blk.
// TODO: Move this to blockwise pipeline base
static constexpr index_t KPack = // = num of pk_f4
math::max(lcm_AK1_BK1, // num of pk_f4
// KPack in packed data types for pk A/B
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk /
2); // num of f4
2);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1567,6 +1568,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// shuffle C and write out
{
// printf("c_thread_buf %f %f\n", c_thread_buf[I0], c_thread_buf[I1]);
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");

View File

@@ -1140,6 +1140,11 @@ struct MfmaSelector
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
// Explicit specialization of GetMfma for the argument list
// <f4_t, 16, 16, f4_t, false, true>; it returns the
// MfmaInstr::mfma_scale_f32_16x16x128f8f6f4 enumerator.
// NOTE(review): the arguments presumably mean A type, M per XDL, N per XDL,
// B type, is_single_rate_mfma and is_scale_mfma, in that order (mirroring the
// MfmaSelector parameter order) -- confirm against the primary template.
template <>
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16>()
@@ -1443,7 +1448,7 @@ struct XdlopsGemm
const ScaleB& b_scale_thread,
FloatC& p_c_thread) const
{
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
static_for<0, KPack * 2 / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(

View File

@@ -1572,6 +1572,11 @@ inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
return f4_convert_rne(x);
#endif
}
// Converts a float2_t value to f4x2_pk_t in two steps: the existing
// float2_t -> f4x2_t type_convert specialization does the numeric conversion,
// and the result is then static_cast to f4x2_pk_t.
template <>
inline __host__ __device__ f4x2_pk_t type_convert<f4x2_pk_t, float2_t>(float2_t x)
{
    // Step 1: reuse the f4x2_t conversion path for the actual rounding.
    f4x2_t converted = type_convert<f4x2_t>(x);
    // Step 2: re-express the converted value as the f4x2_pk_t type.
    return static_cast<f4x2_pk_t>(converted);
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>

View File

@@ -84,6 +84,7 @@ struct ReferenceMXGemm : public device::BaseOperator
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
printf("K: %d\n", K);
for(size_t m = 0; m < M; m++)
{
@@ -95,15 +96,29 @@ struct ReferenceMXGemm : public device::BaseOperator
if(k % 2 == 1)
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
f4_t(arg.a_m_k_(m, k / 2).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
else
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
f4_t(arg.a_m_k_(m, k / 2).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
if(m == 0)
{
printf("a_m_k_scaled(%zu, %zu): %f = %f * %f\n",
m,
k,
a_m_k_scaled(m, k),
k % 2 == 1
? type_convert<ComputeTypeA>(f4_t(
arg.a_m_k_(m, k / 2).template unpack<>(Number<1>{})))
: type_convert<ComputeTypeA>(f4_t(
arg.a_m_k_(m, k / 2).template unpack<>(Number<0>{}))),
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)));
}
}
else
{
@@ -124,13 +139,13 @@ struct ReferenceMXGemm : public device::BaseOperator
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
f4_t(arg.b_k_n_(k / 2, n).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
else
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
f4_t(arg.b_k_n_(k / 2, n).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}