mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
wip2
This commit is contained in:
@@ -200,10 +200,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: // Initializations for development and debugging
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(ck::float2_t(1.0f))}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(ck::float2_t(1.0f))}(b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Init A = {1}" << std::endl;
|
||||
@@ -347,16 +347,16 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
std::cout << "Comparing results..." << std::endl;
|
||||
}
|
||||
|
||||
if(config.init_method == 0)
|
||||
{
|
||||
auto expected = static_cast<float>(K);
|
||||
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
|
||||
// if(config.init_method == 0)
|
||||
// {
|
||||
// auto expected = static_cast<float>(K);
|
||||
// auto computed = type_convert<float>(c_m_n_device_result(1, 12));
|
||||
|
||||
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
|
||||
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
|
||||
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
|
||||
<< std::endl;
|
||||
}
|
||||
// res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
|
||||
// std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
|
||||
// << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
|
||||
// << std::endl;
|
||||
// }
|
||||
|
||||
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
|
||||
@@ -45,24 +45,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // BlockSize: Thread block size
|
||||
16, // MPerBlock
|
||||
16, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
1, // NXdlPerWave
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
@@ -71,8 +71,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
|
||||
@@ -344,13 +344,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
// a_thread_copy_.Run(
|
||||
// a_block_desc_m0_m1_m2_k,
|
||||
// make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
// a_block_buf,
|
||||
// a_thread_desc_,
|
||||
// make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
// a_thread_buf);
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
@@ -406,13 +406,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_buf[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
CK_TILE_PRINT<xdlops_gemm.K1PerXdlops>();
|
||||
CK_TILE_PRINT<decltype(xdlops_gemm)>();
|
||||
// CK_TILE_PRINT<xdlops_gemm.K1PerXdlops>();
|
||||
// CK_TILE_PRINT<decltype(xdlops_gemm)>();
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
// mfma input type = pk_f4_t, 32
|
||||
CK_TILE_PRINT<mfma_input_type_a>();
|
||||
// CK_TILE_PRINT<mfma_input_type_a>();
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
@@ -538,6 +538,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
});
|
||||
printf("a_thread_buf: %x %x\n",
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_buf[I0]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_buf[I0]));
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
|
||||
@@ -167,15 +167,16 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
//
|
||||
// Should be a multiple of k_per_blk.
|
||||
// TODO: Move this to blockwise pipeline base
|
||||
static constexpr index_t KPack = // = num of pk_f4
|
||||
math::max(lcm_AK1_BK1, // num of pk_f4
|
||||
// KPack in packed data types for pk A/B
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk /
|
||||
2); // num of f4
|
||||
2);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -1567,6 +1568,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
// printf("c_thread_buf %f %f\n", c_thread_buf[I0], c_thread_buf[I1]);
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
|
||||
@@ -1140,6 +1140,11 @@ struct MfmaSelector
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
@@ -1443,7 +1448,7 @@ struct XdlopsGemm
|
||||
const ScaleB& b_scale_thread,
|
||||
FloatC& p_c_thread) const
|
||||
{
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
static_for<0, KPack * 2 / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
|
||||
@@ -1572,6 +1572,11 @@ inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
|
||||
return f4_convert_rne(x);
|
||||
#endif
|
||||
}
|
||||
template <>
|
||||
inline __host__ __device__ f4x2_pk_t type_convert<f4x2_pk_t, float2_t>(float2_t x)
|
||||
{
|
||||
return static_cast<f4x2_pk_t>(type_convert<f4x2_t>(x));
|
||||
}
|
||||
|
||||
// convert vector of 32 fp32 to vector of 32 fp4
|
||||
template <>
|
||||
|
||||
@@ -84,6 +84,7 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
|
||||
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
|
||||
printf("K: %d\n", K);
|
||||
|
||||
for(size_t m = 0; m < M; m++)
|
||||
{
|
||||
@@ -95,15 +96,29 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
if(k % 2 == 1)
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
|
||||
f4_t(arg.a_m_k_(m, k / 2).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
else
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
|
||||
f4_t(arg.a_m_k_(m, k / 2).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
if(m == 0)
|
||||
{
|
||||
printf("a_m_k_scaled(%zu, %zu): %f = %f * %f\n",
|
||||
m,
|
||||
k,
|
||||
a_m_k_scaled(m, k),
|
||||
k % 2 == 1
|
||||
? type_convert<ComputeTypeA>(f4_t(
|
||||
arg.a_m_k_(m, k / 2).template unpack<>(Number<1>{})))
|
||||
: type_convert<ComputeTypeA>(f4_t(
|
||||
arg.a_m_k_(m, k / 2).template unpack<>(Number<0>{}))),
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -124,13 +139,13 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
if(k % 2 == 1)
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
|
||||
f4_t(arg.b_k_n_(k / 2, n).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
else
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
|
||||
f4_t(arg.b_k_n_(k / 2, n).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user