mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
mfma using asm: device result is correct, host result still needs to be checked
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
@@ -315,8 +315,12 @@ int main(int argc, char* argv[])
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
// a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1.0, 1.0});
|
||||
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.0, 1.0});
|
||||
ck::utils::FillConstant<A0DataType>{ck::type_convert<A0DataType>(ck::float2_t(1.0f))}(
|
||||
a0_t_k_k);
|
||||
ck::utils::FillConstant<B0DataType>{ck::type_convert<B0DataType>(ck::float2_t(1.0f))}(
|
||||
b0_e_n_k);
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
@@ -360,6 +364,78 @@ int main(int argc, char* argv[])
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
#if 1
|
||||
printf("a0_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data >> 4 & 0xf));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k).data & 0xf));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("a1_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
{
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("b0_e_n_k:\n");
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data >> 4 & 0xf));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n).data & 0xf));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("b1_e_n_k:\n");
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
|
||||
@@ -778,6 +778,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
#if 0
|
||||
printf("bidx: %u, bidy: %u, tidx: %u, a_thread_vec: %02x, %02x, %02x, %02x, b_thread_vec: %02x, %02x, %02x, %02x,"
|
||||
"a_scale_thread_vec: %02x, b_scale_vec: %02x\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
|
||||
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()[Number<3>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeB>()[Number<3>{}]))),
|
||||
|
||||
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<0>{}].unpack(Number<0>{})),
|
||||
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<0>{}].unpack(Number<1>{})),
|
||||
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<1>{}].unpack(Number<0>{})),
|
||||
// type_convert<float>(a_thread_vec.template AsType<ComputeTypeA>()[Number<1>{}].unpack(Number<1>{})),
|
||||
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<0>{}].unpack(Number<0>{})),
|
||||
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<0>{}].unpack(Number<1>{})),
|
||||
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<1>{}].unpack(Number<0>{})),
|
||||
// type_convert<float>(b_thread_vec.template AsType<ComputeTypeB>()[Number<1>{}].unpack(Number<1>{})),
|
||||
*(reinterpret_cast<const uint8_t*>(&(a_scale_thread_vec.template AsType<AScaleDataType>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint8_t*>(&(b_scale_thread_vec.template AsType<BScaleDataType>()[Number<0>{}]))));
|
||||
#endif
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
@@ -796,8 +824,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp<BlockGemmPipelineSched
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
#if 0
|
||||
printf("bidx: %u, bidx: %u, tidx: %u, c_thread_buf: %f, %f, %f, %f\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
(c_thread_buf[Number<0>{}]),
|
||||
(c_thread_buf[Number<1>{}]),
|
||||
(c_thread_buf[Number<2>{}]),
|
||||
(c_thread_buf[Number<3>{}]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
#if 0
|
||||
printf("bidx: %u, bidx: %u, tidx: %u, c_thread_buf: %f, %f, %f, %f\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<0>{}]),
|
||||
type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<1>{}]),
|
||||
type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<2>{}]),
|
||||
type_convert<float>(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType<float>()[Number<3>{}]));
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1310,7 +1310,7 @@ struct GridwiseMoeGemmMX
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
|
||||
});
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
|
||||
@@ -762,6 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
#if 1
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
@@ -773,6 +774,43 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 "
|
||||
"op_sel:[0,0]"
|
||||
"op_sel_hi:[0,0]"
|
||||
"cbsz:4"
|
||||
" blgp:4"
|
||||
: "+v"(reg_c.template AsType<float4_t>()(Number<0>{}))
|
||||
: "v"(bit_cast<int32x4_t>(arg_a)),
|
||||
"v"(bit_cast<int32x4_t>(arg_b)),
|
||||
"v"(reg_c.template AsType<float4_t>()[Number<0>{}]),
|
||||
"v"(scale_a),
|
||||
"v"(scale_b));
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
printf("bidx: %u, bidy: %u, tid: %u, A: %08x, %08x, %08x, %08x,"
|
||||
"B:%08x, %08x, %08x, %08x, a_scale: %08x, b_scale: %08x, "
|
||||
"reg_c: %f, %f, %f, %f\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
bit_cast<uint32_t>(arg_a[0]),
|
||||
bit_cast<uint32_t>(arg_a[1]),
|
||||
bit_cast<uint32_t>(arg_a[2]),
|
||||
bit_cast<uint32_t>(arg_a[3]),
|
||||
bit_cast<uint32_t>(arg_b[0]),
|
||||
bit_cast<uint32_t>(arg_b[1]),
|
||||
bit_cast<uint32_t>(arg_b[2]),
|
||||
bit_cast<uint32_t>(arg_b[3]),
|
||||
*(reinterpret_cast<const uint32_t*>(&(scale_a))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(scale_b))),
|
||||
reg_c.template AsType<float>()[Number<0>{}],
|
||||
reg_c.template AsType<float>()[Number<1>{}],
|
||||
reg_c.template AsType<float>()[Number<2>{}],
|
||||
reg_c.template AsType<float>()[Number<3>{}]);
|
||||
#endif
|
||||
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
|
||||
Reference in New Issue
Block a user