diff --git a/CMakeLists.txt b/CMakeLists.txt index eaed7d3509..4ea1253752 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -162,7 +162,7 @@ execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMI configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h) set(ROCM_SYMLINK_LIBS OFF) -find_package(ROCM REQUIRED PATHS /opt/rocm) +find_package(ROCM REQUIRED PATHS /opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel) include(ROCMInstallTargets) include(ROCMPackageConfigHelpers) diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.cpp b/example/ck_tile/42_mx_gemm/mx_gemm.cpp index f6c7c1c758..ca76be407e 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.cpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.cpp @@ -31,7 +31,7 @@ template + bool UsePersistentKernel = true> float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_dev_buf, ck_tile::DeviceMem& c_dev_buf, @@ -83,7 +83,7 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, GemmConfig::UseStructuredSparsity, UsePersistentKernel, GemmConfig::NumWaveGroups, - true>; + false>; using MXPipelineProblem = MXGemmPipelineProblem }; // GEMM config with 16x16 warp tile -struct MXfp4_GemmConfig16 + +struct MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 256; @@ -70,3 +71,17 @@ struct MXfp4_GemmConfig16 static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; +struct MXfp4_GemmConfig16 : MxGemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; +}; + +// GEMM config with 16x16 warp tile +struct MXfp8_GemmConfig16 : MxGemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; +}; diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc 
b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc index fdaa57fa7b..11f687a6ef 100644 --- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc +++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc @@ -49,25 +49,25 @@ int run_mx_gemm_with_layouts(int argc, // Scale tensors // Assuming block scale 32 - ck_tile::index_t scale_n_size = N / 32; + using ScaleType = ck_tile::e8m0_t; ck_tile::index_t scale_k_size = K / 32; - ck_tile::HostTensor scale_a_host( + ck_tile::HostTensor scale_a_host( ck_tile::HostTensorDescriptor({M, scale_k_size}, {scale_k_size, 1})); - ck_tile::HostTensor scale_b_host( - ck_tile::HostTensorDescriptor({scale_k_size, scale_n_size}, {scale_n_size, 1})); + ck_tile::HostTensor scale_b_host( + ck_tile::HostTensorDescriptor({scale_k_size, N}, {1, scale_k_size})); switch(init_method) { case 0: ck_tile::FillUniformDistribution{-1.f, 1.f}(a_host); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_host); - ck_tile::FillUniformDistribution{-1.f, 1.f}(scale_a_host); - ck_tile::FillUniformDistribution{-1.f, 1.f}(scale_b_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(scale_a_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(scale_b_host); break; case 1: ck_tile::FillConstant{ADataType(1.f)}(a_host); ck_tile::FillConstant{BDataType(1.f)}(b_host); - ck_tile::FillConstant{ck_tile::e8m0_t(1.f)}(scale_a_host); - ck_tile::FillConstant{ck_tile::e8m0_t(1.f)}(scale_b_host); + ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host); + ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host); break; } @@ -83,8 +83,8 @@ int run_mx_gemm_with_layouts(int argc, scale_b_dev_buf.ToDevice(scale_b_host.data()); // Scale pointers - using ScaleM = ck_tile::MXScalePointer<1, 32>; // per-token - using ScaleN = ck_tile::MXScalePointer<32, 32>; // per-block + using ScaleM = ck_tile::MXScalePointer<1, 32>; // in blocks of 32 in K + using ScaleN = ck_tile::MXScalePointer<1, 32>; ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer())); ScaleN 
scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer())); @@ -104,14 +104,31 @@ int run_mx_gemm_with_layouts(int argc, (void)ave_time; + bool pass = true; if(validation > 0) { + // get output data from device c_dev_buf.FromDevice(c_host.data()); - // TODO: Implement validation logic (reference GEMM with scales) - // For now just print success if it runs - std::cout << "Validation not implemented yet." << std::endl; + + // compute reference + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_mx_gemm( + a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host); + + const float rtol = std::is_same_v ? 1e-3 : 1e-2; + const float atol = std::is_same_v ? 1e-3 : 1e-2; + + pass = ck_tile::check_err( + c_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol); + + std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } - return 0; + return pass ? 
0 : -1; } int run_mx_gemm_example(int argc, char* argv[]) @@ -126,24 +143,28 @@ int run_mx_gemm_example(int argc, char* argv[]) std::string mx_prec = arg_parser.get_str("mx_prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - int persistent_opt = arg_parser.get_int("persistent"); if(a_layout == "R" && b_layout == "C") { if(mx_prec == "fp4" || mx_prec == "fp4xfp4") { - if(persistent_opt == 0) - return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); - else - throw std::runtime_error("Only non-persistent kernels are supported currently!"); + return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") + { + return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else { - throw std::runtime_error("Only fp4xfp4 is supported currently!"); + throw std::runtime_error("Only fp4 and fp8 are supported currently!"); } } else diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 08d555d27c..7830749efb 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -39,12 +39,8 @@ #define CK_TILE_DEVICE inline __device__ #define CK_TILE_HOST_DEVICE inline __host__ __device__ #define CK_TILE_DEVICE_EXTERN __device__ -#if __clang_major__ < 22 #define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__ #else -#define CK_TILE_HOST_DEVICE_EXTERN -#endif -#else #define CK_TILE_HOST inline #define CK_TILE_DEVICE inline #define CK_TILE_HOST_DEVICE inline diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index e188ddec61..5703983d30 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -119,6 +119,7 @@ struct 
MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution() { + // TODO: these could be replaced by the standard UniversalGEMM tile distributions?? constexpr index_t K2 = AK1; // f4=32; f8=16 constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp index 0ade057bcb..8a0ce78762 100644 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp @@ -95,6 +95,8 @@ struct MXGemmKernel : UniversalGemmKernel::PackedSize; static constexpr auto BPackedSize = numeric_traits::PackedSize; + /// @brief The e8m0 scales are packed into int32/float32 such that + /// in one element contains a 2x2 block of scales (two rows, two lements in K dim) static constexpr auto MXdlPack = MXGemmPipeline::MXdlPack; static constexpr auto NXdlPack = MXGemmPipeline::NXdlPack; static constexpr auto KXdlPack = MXGemmPipeline::KXdlPack; @@ -195,7 +197,8 @@ struct MXGemmKernel : UniversalGemmKernel{}, sequence<1, 2>{}), @@ -251,12 +254,14 @@ struct MXGemmKernel : UniversalGemmKernel{}, number{}), {i_m / MXdlPack, 0}); + // We are packing 2x2 (NXdlPack x KXdlPack) scales (e8m0) into one int32 element auto scale_b_block_window = make_tile_window( views.at(I5), make_tuple(number{}, @@ -295,7 +300,7 @@ struct MXGemmKernel : UniversalGemmKernel( - scale_m_ptr_offset.ptr, - make_tuple(number{}, number{}), - make_tuple(number<1>{}, number<0>{}), - number<1>{}, - number<1>{} - ); - } else { - return typename EpiloguePipeline::EmptyScale{}; - } - }(); - - auto scale_n_view = [&]() { - if constexpr (ScaleN::GranularityMN != -1) { - return make_naive_tensor_view( - scale_n_ptr_offset.ptr, - make_tuple(number{}, number{}), - make_tuple(number<0>{}, number<1>{}), - number<1>{}, - 
number<1>{} - ); - } else { - return typename EpiloguePipeline::EmptyScale{}; - } - }(); - - EpiloguePipeline{}(c_block_window, - c_block_tile, - d_block_window, - smem_ptr_ping, - scale_m_view, - scale_n_view); - } - else if(UseDefaultScheduler || (get_warp_id() == 0)) - { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); - } + auto& c_block_window = gemm_tile_windows.at(I3); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize() diff --git a/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1.hpp b/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1.hpp index 619f80f5f7..dda2d02d7f 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1.hpp @@ -242,7 +242,7 @@ struct MXGemmPipelineAgBgCrV1 move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); }; - // Helper for Math Loop + // Helper for Main Loop auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) { // Define register tiles types for double buffering using AValType = decltype(load_tile_with_offset(a_warp_window, tuple, number<0>>{})); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1_policy.hpp index 441e7d71be..c688a5e826 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1_policy.hpp @@ -227,31 +227,10 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence>, tuple, sequence<2, 1>>, tuple, sequence<1, 2>>, - sequence<2, 2>, + sequence<2, 2>, // K_Thread/AK1, AK1 sequence<0, 2>>{}); } - CK_TILE_DEVICE static constexpr 
auto MakeMX_BDramTileDistribution() - { - constexpr index_t K2 = BK1; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 - - constexpr index_t N2 = WaveSize / K1; // 8 - constexpr index_t N1 = BlockSize / WaveSize; // 4 - constexpr index_t N0 = NPerBlock / (N2 * N1); - static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!"); - static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); - - return make_static_tile_distribution( - tile_distribution_encoding< // - sequence<1>, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, // N0,K0,K2 - sequence<0, 0, 2>>{}); - } template CK_TILE_DEVICE static constexpr auto @@ -294,6 +273,29 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy TensorView::DstInMemOp>{naive_view.buf_, desc}; } + CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution() + { + // TODO: these could be replaced by the standard UniversalGEMM tile distributions?? + constexpr index_t K2 = BK1; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8 + constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 + + constexpr index_t N2 = WaveSize / K1; // 8 + constexpr index_t N1 = BlockSize / WaveSize; // 4 + constexpr index_t N0 = NPerBlock / (N2 * N1); + static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!"); + static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, // N0,K0,K2 + sequence<0, 0, 2>>{}); + } + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor() { constexpr index_t K2 = BK1; // f4=32; f8=16