Implement MFMA using inline asm; the device result is correct, the host result still needs to be checked

This commit is contained in:
mtgu0705
2025-05-13 20:57:34 -05:00
parent 6dfe24c53e
commit 1bbb50b212
4 changed files with 167 additions and 4 deletions

View File

@@ -18,7 +18,7 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
@@ -315,8 +315,12 @@ int main(int argc, char* argv[])
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
// a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1.0, 1.0});
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.0, 1.0});
ck::utils::FillConstant<A0DataType>{ck::type_convert<A0DataType>(ck::float2_t(1.0f))}(
a0_t_k_k);
ck::utils::FillConstant<B0DataType>{ck::type_convert<B0DataType>(ck::float2_t(1.0f))}(
b0_e_n_k);
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
@@ -360,6 +364,78 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
#if 1
printf("a0_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
{
for(int k = 0; k < K; ++k)
{
if(k % 2 == 0)
{
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data >> 4 & 0xf));
}
else
{
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data & 0xf));
}
}
printf("\n");
}
printf("\n");
}
printf("a1_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
{
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
printf("%f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
}
printf("\n");
}
printf("\n");
}
printf("b0_e_n_k:\n");
for(int e = 0; e < experts; ++e)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
if(k % 2 == 0)
{
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data >> 4 & 0xf));
}
else
{
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data & 0xf));
}
}
printf("\n");
}
printf("\n");
}
printf("b1_e_n_k:\n");
for(int e = 0; e < experts; ++e)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
printf("%f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
}
printf("\n");
}
printf("\n");
}
#endif
// do GEMM
auto device_op = DeviceOpInstance{};

View File

@@ -778,6 +778,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
});
#if 0
// Disabled per-thread debug trace of the MFMA operands.
// Prints blockIdx.x, blockIdx.y and threadIdx.x, then, as two-digit hex:
//   - the first byte of elements 0..3 of a_thread_vec viewed as ComputeTypeA,
//   - the first byte of elements 0..3 of b_thread_vec viewed as ComputeTypeB,
//   - the first byte of element 0 of the A and B scale vectors.
// The two adjacent string literals are concatenated by the compiler, so the
// output reads "...%02x,a_scale_thread_vec: ..." with no space after the comma.
printf("bidx: %u, bidy: %u, tidx: %u, a_thread_vec: %02x, %02x, %02x, %02x, b_thread_vec: %02x, %02x, %02x, %02x,"
"a_scale_thread_vec: %02x, b_scale_vec: %02x\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<0>{}]))),
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<1>{}]))),
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<2>{}]))),
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<3>{}]))),
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<0>{}]))),
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<1>{}]))),
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<2>{}]))),
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<3>{}]))),
// Alternative (commented out): print the unpacked halves of the first two A/B
// elements converted to float instead of the raw bytes.
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<0>{}].unpack(Number<0>{})),
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<0>{}].unpack(Number<1>{})),
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<1>{}].unpack(Number<0>{})),
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<1>{}].unpack(Number<1>{})),
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<0>{}].unpack(Number<0>{})),
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<0>{}].unpack(Number<1>{})),
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<1>{}].unpack(Number<0>{})),
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<1>{}].unpack(Number<1>{})),
*(reinterpret_cast<const uint8_t*>(&(a_scale_thread_vec.template AsType<AScaleDataType>()[Number<0>{}]))),
*(reinterpret_cast<const uint8_t*>(&(b_scale_thread_vec.template AsType<BScaleDataType>()[Number<0>{}]))));
#endif
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
@@ -796,8 +824,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
#if 0
// Disabled per-thread debug trace of the accumulator: prints entries 0..3 of
// c_thread_buf together with the block and thread indices.
// Labels follow the arguments: bidx = blockIdx.x, bidy = blockIdx.y,
// tidx = threadIdx.x.
printf("bidx: %u, bidy: %u, tidx: %u, c_thread_buf: %f, %f, %f, %f\n",
       blockIdx.x,
       blockIdx.y,
       threadIdx.x,
       (c_thread_buf[Number<0>{}]),
       (c_thread_buf[Number<1>{}]),
       (c_thread_buf[Number<2>{}]),
       (c_thread_buf[Number<3>{}]));
#endif
});
});
#if 0
// Disabled per-thread debug trace of the accumulator: prints the first four
// float lanes of the vector at c_thread_buf offset 0 together with the block
// and thread indices.
// Labels follow the arguments: bidx = blockIdx.x, bidy = blockIdx.y,
// tidx = threadIdx.x.
printf("bidx: %u, bidy: %u, tidx: %u, c_thread_buf: %f, %f, %f, %f\n",
       blockIdx.x,
       blockIdx.y,
       threadIdx.x,
       type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<0>{}]),
       type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<1>{}]),
       type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<2>{}]),
       type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<3>{}]));
#endif
}
}

View File

@@ -1310,7 +1310,7 @@ struct GridwiseMoeGemmMX
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));

View File

@@ -762,6 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
#if 1
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
@@ -773,6 +774,43 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
scale_a,
0, // OPSEL
scale_b);
#else
// Inline-assembly form of the scaled 16x16x128 MFMA.
// Operand mapping of the template string:
//   %0 - float4_t view of reg_c, read-write ("+v")
//   %1 - arg_a bit-cast to int32x4_t
//   %2 - arg_b bit-cast to int32x4_t
//   %3 - current value of the same reg_c lane group, passed again as an input
//   %4 - scale_a
//   %5 - scale_b
// The adjacent string literals are concatenated by the compiler, so the
// modifier list is emitted as "op_sel:[0,0]op_sel_hi:[0,0]cbsz:4 blgp:4".
// NOTE(review): there is no whitespace between the first three modifiers -
// confirm the assembler accepts this spelling.
// NOTE(review): cbsz:4 / blgp:4 presumably select the A/B input formats -
// verify against the ISA documentation.
asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 "
"op_sel:[0,0]"
"op_sel_hi:[0,0]"
"cbsz:4"
" blgp:4"
: "+v"(reg_c.template AsType<float4_t>()(Number<0>{}))
: "v"(bit_cast<int32x4_t>(arg_a)),
"v"(bit_cast<int32x4_t>(arg_b)),
"v"(reg_c.template AsType<float4_t>()[Number<0>{}]),
"v"(scale_a),
"v"(scale_b));
#endif
// Opt-in debug trace of one scaled-MFMA call. When enabled, every thread prints
// its block/thread indices, the four 32-bit words of arg_a and arg_b, the raw
// 32-bit patterns of scale_a and scale_b, and the first four float lanes of
// reg_c after the instruction.
// This runs once per thread per MFMA call, so it is compiled only when
// CK_DEBUG_MFMA_SCALE_PRINT is defined (e.g. -DCK_DEBUG_MFMA_SCALE_PRINT).
#if defined(CK_DEBUG_MFMA_SCALE_PRINT)
printf("bidx: %u, bidy: %u, tid: %u, A: %08x, %08x, %08x, %08x,"
       "B:%08x, %08x, %08x, %08x, a_scale: %08x, b_scale: %08x, "
       "reg_c: %f, %f, %f, %f\n",
       blockIdx.x,
       blockIdx.y,
       threadIdx.x,
       bit_cast<uint32_t>(arg_a[0]),
       bit_cast<uint32_t>(arg_a[1]),
       bit_cast<uint32_t>(arg_a[2]),
       bit_cast<uint32_t>(arg_a[3]),
       bit_cast<uint32_t>(arg_b[0]),
       bit_cast<uint32_t>(arg_b[1]),
       bit_cast<uint32_t>(arg_b[2]),
       bit_cast<uint32_t>(arg_b[3]),
       *(reinterpret_cast<const uint32_t*>(&(scale_a))),
       *(reinterpret_cast<const uint32_t*>(&(scale_b))),
       reg_c.template AsType<float>()[Number<0>{}],
       reg_c.template AsType<float>()[Number<1>{}],
       reg_c.template AsType<float>()[Number<2>{}],
       reg_c.template AsType<float>()[Number<3>{}]);
#endif
#else
ignore = reg_a;
ignore = scale_a;