diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index b649215bc6..d42ee8558b 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -133,13 +133,15 @@ void preShuffleScaleBuffer(const ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int int k0 = k / (XdlKThread * KXdlPack); // i KRepeat int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk / KXdlPack; // i XdlKThread - int k2 = tempk % KXdlPack; // i KXdlPack + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; + k1 * MNXdlPack * KXdlPack * XdlMNThread + + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + + n2; dst[outputIndex] = src[n * K + k]; } @@ -332,6 +334,27 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "NOTE: No input data initialization." << std::endl; } } + printf("a_scale:\n"); + for (size_t i = 0; i < M; i++) + { + for (size_t j = 0; j < K / ScaleBlockSize; j++) + { + a_m_k_scale(i, j) = ck::type_convert(static_cast(powf(2.0f, (j/4)%4))); + printf("%02x ", *reinterpret_cast(&a_m_k_scale(i, j))); + } + printf("\n"); + } + printf("b_scale:\n"); + for (size_t i = 0; i < N; i++) + { + for (size_t j = 0; j < K / ScaleBlockSize; j++) + { + b_k_n_scale(j, i) = ck::type_convert(static_cast(powf(2.0f, (j/4)%4))); + printf("%02x ", *reinterpret_cast(&b_k_n_scale(j, i))); + } + printf("\n"); + } + #if 1 preShuffleScaleBuffer( a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize); @@ -339,6 +362,22 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); #endif + printf("a_shuffled_scale:\n"); + for (size_t i = 0; i < M*K / ScaleBlockSize; i++) + { + printf("%02x ", *reinterpret_cast(&(a_shuffled_scale.mData.data()[i]))); + + if(i % 64 == 63) + printf("\n"); + } + printf("b_shuffled_scale:\n"); + for (size_t i = 0; i < N*K / ScaleBlockSize; i++) + { + printf("%02x ", *reinterpret_cast(&(b_shuffled_scale.mData.data()[i]))); + if(i % 64 == 63) + printf("\n"); + } + if(config.verbosity > 0) std::cout << "Device memory allocation..." << std::endl; DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); @@ -353,14 +392,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); - // for (size_t i = 0; i < N; i++) - // { - // for (size_t j = 0; j < K / ScaleBlockSize; j++) - // { - // printf("%02x ", *reinterpret_cast(&b_shuffled_scale(j, i))); - // } - // printf("\n"); - // } if(config.verbosity > 0) std::cout << "Done." << std::endl; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp index b8b248c23d..7bb0643f46 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -367,19 +367,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx(&a_scale_thread_bufs(I0)[Number<0>{}]), *reinterpret_cast(&a_scale_thread_bufs(I0)[Number<1>{}]), *reinterpret_cast(&a_scale_thread_bufs(I0)[Number<2>{}]), *reinterpret_cast(&a_scale_thread_bufs(I0)[Number<3>{}])); } - +#endif // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, 0, 0)); + a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); // Prefetch b_scales static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { @@ -396,20 +397,21 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx(&b_scale_thread_bufs(I0)[Number<0>{}]), *reinterpret_cast(&b_scale_thread_bufs(I0)[Number<1>{}]), *reinterpret_cast(&b_scale_thread_bufs(I0)[Number<2>{}]), *reinterpret_cast(&b_scale_thread_bufs(I0)[Number<3>{}])); } - +#endif // restore col id and advance to the next set of scales // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, 0, 0)); + b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); @@ -483,7 +485,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto n0) { @@ -504,7 +506,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { @@ -710,7 +705,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx(&a_scale_thread_bufs(I1)[Number<0>{}]), + *reinterpret_cast(&a_scale_thread_bufs(I1)[Number<1>{}]), + *reinterpret_cast(&a_scale_thread_bufs(I1)[Number<2>{}]), + *reinterpret_cast(&a_scale_thread_bufs(I1)[Number<3>{}])); + } + if(get_thread_local_1d_id()) + { + printf("2stGMEM Tid: %03d, Scale B: %02x %02x %02x %02x\n", + get_thread_local_1d_id(), + *reinterpret_cast(&b_scale_thread_bufs(I1)[Number<0>{}]), + *reinterpret_cast(&b_scale_thread_bufs(I1)[Number<1>{}]), + *reinterpret_cast(&b_scale_thread_bufs(I1)[Number<2>{}]), + *reinterpret_cast(&b_scale_thread_bufs(I1)[Number<3>{}])); + } +#endif block_sync_lds(); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -796,7 +811,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}], + c_thread_buf[Number<1>{}], + c_thread_buf[Number<2>{}], + c_thread_buf[Number<3>{}]); + } +#endif block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { @@ -911,6 +936,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}], + c_thread_buf[Number<1>{}], + c_thread_buf[Number<2>{}], + c_thread_buf[Number<3>{}]); + } +#endif } else if constexpr(TailNum == TailNumber::Odd) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index f71f47d0be..56c5b25d32 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -424,6 +424,15 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX; + Run(kernel); + } } return ave_time; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 1927eae105..586d99cd64 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -1279,7 +1279,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 } } } - +#if 0 // check gridwise gemm pipeline const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); @@ -1290,7 +1290,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 return false; } } - +#endif // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -1518,9 +1518,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // mfma.selected_mfma.num_threads_per_blk; // A wave access continuous memory - auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize; + auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; - auto a_thread_offset_m = waveId_m * MPerXdl * MXdlPack; + auto a_thread_offset_m = waveId_m; auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2 } }; +__host__ __device__ inline constexpr int32_t get_exponent_value_ex(e8m0x4_bexp_t x) +{ + return ( + (static_cast(x.template AsType()[Number<0>{}].data) ) | + (static_cast(x.template AsType()[Number<1>{}].data) << 8) | + (static_cast(x.template AsType()[Number<2>{}].data) << 16 ) | + (static_cast(x.template AsType()[Number<3>{}].data) << 24)); +} + template <> struct mfma_type { @@ -899,20 +908,18 @@ struct mfma_type const ScaleB& scale_b, FloatC& reg_c) const { - if(get_thread_local_1d_id() == 0) - { - printf("Before BitCast: Scale A: %08x, Scale B: %08x\n", - *reinterpret_cast(&scale_a), - *reinterpret_cast(&scale_b)); - } + // if(get_thread_local_1d_id() == 0) + // { + // printf("Before BitCast: Scale A: %08x, Scale B: %08x\n", + // *reinterpret_cast(&scale_a), + // *reinterpret_cast(&scale_b)); + // } // static_assert(scalar_type::vector_size == 1, "Expect single scale at this // point."); static_assert(scalar_type::vector_size == 1, "Expect single scale at // this point."); - // intrin_mfma_scale_f32_16x16x128f8f6f4::Run( - // a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); intrin_mfma_scale_f32_16x16x128f8f6f4::Run( - a, bit_cast(scale_a), b, bit_cast(scale_b), reg_c); + a, get_exponent_value_ex(scale_a), b, get_exponent_value_ex(scale_b), reg_c); } }; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index f4ccf9d232..0be477077e 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -756,11 +756,13 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> const int32_t scale_b, FloatC& reg_c) { - // if(get_thread_local_1d_id()){ - // printf("Scale A: %08x, Scale B: %08x\n", - // *reinterpret_cast(&scale_a), *reinterpret_cast(&scale_b)); - // } + if(get_thread_local_1d_id()){ + printf("Tid: %03d, Scale A: %08x, Scale B: %08x, OpSelA: %d, OpSelB: %d\n", + get_thread_local_1d_id(), + *reinterpret_cast(&scale_a), *reinterpret_cast(&scale_b), + OpselA, OpselB); + } #if defined(__gfx950__) int32x4_t arg_a = bit_cast(reg_a); int32x4_t arg_b = bit_cast(reg_b); diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 66999f87b6..d0c91c4e56 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -2134,6 +2134,7 @@ using f6x32_t = typename vector_type::type; using bf6x16_t = typename vector_type::type; using bf6x32_t = typename vector_type::type; +using e8m0x4_bexp_t = typename vector_type::type; // pack int4 using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type;