diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4c1ca789f5..ba57ead09a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -610,6 +610,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
     PACKAGE_NAME examples
 )
 add_subdirectory(example)
+add_subdirectory(tile_engine)
 if(BUILD_TESTING)
     add_subdirectory(test)
 endif()
diff --git a/tile_engine/CMakeLists.txt b/tile_engine/CMakeLists.txt
new file mode 100755
index 0000000000..cd1a192a74
--- /dev/null
+++ b/tile_engine/CMakeLists.txt
@@ -0,0 +1,5 @@
+include_directories(BEFORE
+    ${CMAKE_CURRENT_LIST_DIR}/include
+)
+
+add_subdirectory(ops)
diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt
new file mode 100755
index 0000000000..0cf2c16da2
--- /dev/null
+++ b/tile_engine/ops/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(gemm)
diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt
new file mode 100644
index 0000000000..d28017ca0c
--- /dev/null
+++ b/tile_engine/ops/gemm/CMakeLists.txt
@@ -0,0 +1,45 @@
+
+
+# Generate the list of kernel files at configure time; no files are actually emitted yet.
+execute_process(
+    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
+            --working_path ${CMAKE_CURRENT_BINARY_DIR}
+            --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
+            --list_blobs
+    RESULT_VARIABLE ret
+)
+
+if(ret AND NOT ret EQUAL 0)
+    message(FATAL_ERROR "Failed to generate the kernel list via Python: ${ret}")
+endif()
+
+file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
+
+add_custom_command(
+    OUTPUT ${GEMM_CODEGEN_BLOBS}
+    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
+            --working_path ${CMAKE_CURRENT_BINARY_DIR}
+            --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
+            --gen_blobs
+    DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
+)
+
+set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
+message("adding example ${EXECUTABLE_GEMM_INSTANCE}")
+
+# Use the build directory as an include directory so the generated headers are found.
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
+target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})
+
+set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
+
+list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
+    -Wno-undefined-func-template
+    -Wno-float-equal
+    --offload-compress)
+
+target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
+
+set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
\ No newline at end of file
diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json
new file mode 100644
index 0000000000..e21197d1de
--- /dev/null
+++ b/tile_engine/ops/gemm/configs/instance_combination.json
@@ -0,0 +1,60 @@
+{
+
+    "layout_a": {
+        "values": ["r"]
+    },
+    "layout_b": {
+        "values": ["c"]
+    },
+    "layout_c": {
+        "values": ["r"]
+    },
+    "datatype": {
+        "values": ["fp16"]
+    },
+    "tile_m": {
+        "values": [256]
+    },
+    "tile_n": {
+        "values": [256]
+    },
+    "tile_k": {
+        "values": [64]
+    },
+    "warp_m": {
+        "values": [2]
+    },
+    "warp_n": {
+        "values": [2]
+    },
+    "warp_k": {
+        "values": [1]
+    },
+    "warp_tile_m": {
+        "values": [32]
+    },
+    "warp_tile_n": {
+        "values": [32]
+    },
+    "warp_tile_k": {
+        "values": [16]
+    },
+    "kPadM": {
+        "values": [false]
+    },
+    "kPadN": {
+        "values": [false]
+    },
+    "kPadK": {
+        "values": [false]
+    },
+    "pipeline": {
"values": ["compv3", "mem"] + }, + "scheduler": { + "values": ["intrawave", "interwave"] + }, + "epilogue": { + "values": ["default", "cshuffle"] + } +} diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp new file mode 100644 index 0000000000..508f634920 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "gemm_common.hpp" +#include "gemm_dispatcher.hpp" +#include "gemm_host_api.hpp" + +float gemm_kernel_launch(KernelTraits& trait, + ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) +{ + return GemmDispatcher::dispatch(trait, args, s); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + const ALayout a_layout = ALayout{}; + const BLayout b_layout = BLayout{}; + // const CLayout c_layout = CLayout{}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + int verify = arg_parser.get_int("v"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + // permute_tensor_b(b_k_n_dev); + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::GemmHostArgs gemm_args; + gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + gemm_args.k_batch = kbatch; + gemm_args.M = M; + gemm_args.N = N; + gemm_args.K = K; + gemm_args.stride_A = stride_A; + gemm_args.stride_B = 
+    gemm_args.stride_C = stride_C;
+
+    KernelTraits trait;
+    trait.pipeline = arg_parser.get_str("pipeline");
+    trait.scheduler = arg_parser.get_str("scheduler");
+    trait.epilogue = arg_parser.get_str("epilogue");
+    trait.kPadM = arg_parser.get_bool("pad_m");
+    trait.kPadN = arg_parser.get_bool("pad_n");
+    trait.kPadK = arg_parser.get_bool("pad_k");
+
+    float ave_time = gemm_kernel_launch(
+        trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_byte =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+    std::cout << "Run Gemm kernel with M = " << M << " N = " << N << " K = " << K
+              << " StrideA = " << stride_A << " StrideB = " << stride_B
+              << " StrideC = " << stride_C << " A_Layout = " << ALayout::name
+              << " B_Layout = " << BLayout::name << " C_Layout = " << CLayout::name
+              << " A Type = " << DataTypeTraits<ADataType>::name
+              << " B Type = " << DataTypeTraits<BDataType>::name
+              << " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
+              << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
+
+    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+    bool pass = true;
+    if(verify)
+    {
+        pass = gemm_verify(verify,
+                           a_m_k,
+                           b_k_n,
+                           c_m_n_dev_result,
+                           a_m_k_dev_buf,
+                           b_k_n_dev_buf,
+                           M,
+                           N,
+                           K,
+                           stride_A,
+                           stride_B,
+                           stride_C,
+                           kbatch);
+    }
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    try
+    {
+        auto [result, parser] = create_args(argc, argv);
+        if(!result)
+            return EXIT_FAILURE;
+        return run<ALayout, BLayout, CLayout>(parser) ? EXIT_SUCCESS : EXIT_FAILURE;
+    }
+    catch(const std::exception& e)
+    {
+        std::cerr << "Error: " << e.what() << "\n";
+        return EXIT_FAILURE;
+    }
+}
diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp
new file mode 100644
index 0000000000..4f0ea52a18
--- /dev/null
+++ b/tile_engine/ops/gemm/gemm_host_api.hpp
@@ -0,0 +1,287 @@
+#pragma once
+
+#include <hip/hip_runtime.h>
+
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include "ck_tile/ops/gemm.hpp"
+
+template <typename DataType>
+struct DataTypeTraits;
+
+template <>
+struct DataTypeTraits<float>
+{
+    static constexpr const char* name = "fp32";
+};
+
+template <>
+struct DataTypeTraits<double>
+{
+    static constexpr const char* name = "fp64";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::half_t>
+{
+    static constexpr const char* name = "fp16";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf16_t>
+{
+    static constexpr const char* name = "bf16";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::fp8_t>
+{
+    static constexpr const char* name = "fp8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf8_t>
+{
+    static constexpr const char* name = "bf8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::pk_int4_t>
+{
+    static constexpr const char* name = "pk_int4_t";
+};
+
+struct KernelTraits
+{
+    std::string pipeline;
+    std::string scheduler;
+    std::string epilogue;
+    bool kPadM;
+    bool kPadN;
+    bool kPadK;
+};
+
+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate
error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("split_k", "1", "splitK value") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("pipeline", "compv3", "compv3, compv4, mem") + .insert("scheduler", "intrawave", "intrawave, interwave") + .insert("epilogue", "cshuffle", "cshuffle, default") + .insert("pad_m", "false", "true, false") + .insert("pad_n", "false", "true, false") + .insert("pad_k", "false", "true, false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 6, i) = i4x2; + } + } + } +} + +// verification code +template +bool gemm_verify(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch) +{ + bool pass = true; + if(verify == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " 
Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); + ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); + ck_tile::hip_check_error( + hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + a_m_k.get_element_space_size_in_bytes(), + hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + b_k_n.get_element_space_size_in_bytes(), + hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + c_m_n_dev_result.get_element_space_size_in_bytes(), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + return pass; +} diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py new file mode 100755 index 0000000000..c0dad03ef0 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -0,0 +1,596 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Dict, Any +import functools +import itertools +import copy +import json +from dataclasses import dataclass + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::half_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t', + 'bf8' : 'ck_tile::bf8_t', + 'int4' : 'ck_tile::pk_int4_t' + } + +LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + +DEFAULT_EPILOGUE = """ + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; +""" + +CSHUFFLE_EPILOGUE = """ + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; +""" +HOT_LOOP_FALSE = """ + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); + } +""" +RUN_MEM = """ + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers"); + } +""" + +RUN_COMPV3 = """ + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("The tail number is wrong. 
It should be Full, Odd, or Even."); + } +""" + +RUN_COMPV4 = """ + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +""" + + +PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], + 'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], + 'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} + +SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave', + 'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'} + +EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE, + 'cshuffle' : CSHUFFLE_EPILOGUE} + +HOT_LOOP_TRUE = {'mem' : RUN_MEM, + 'compv3' : RUN_COMPV3, + 'compv4' : RUN_COMPV4} + + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + +@dataclass +class GemmConfig: + def __init__(self, config_data): + self.matrix_cfg : Dict[str, Any] = {} + self.impl_cfg : Dict[str, Any] = {} + for key, value in config_data.items(): + if key in ["datatype", "layout_a", "layout_b", "layout_c"]: + self.matrix_cfg[key] = value + else: + self.impl_cfg[key] = value + + @property + def datatype(self) -> str: + return self.matrix_cfg["datatype"]["values"][0] + + @property + def layouts(self) -> List[str]: + return [ + self.matrix_cfg["layout_a"]["values"][0], + self.matrix_cfg["layout_b"]["values"][0], + self.matrix_cfg["layout_c"]["values"][0] + ] + + +class GemmCodeGenerator: + def __init__(self, output_dir: str, config: GemmConfig): + self.output_dir = Path(output_dir) + if not self.output_dir.exists(): + self.output_dir.mkdir() + + self.config = config + self.all_kernels = [] + self.unique_configs = [] + # Validate configurations + self._validate_config() + + def _validate_config(self): + """Validate matrix and implementation configurations""" + # Matrix config validation + for param in ["datatype", "layout_a", "layout_b", "layout_c"]: + if len(self.config.matrix_cfg[param]["values"]) != 1: + raise ValueError(f"Matrix config {param} must have exactly one value") + + # Implementation traits validation + required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k", + "warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline", + "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"] + for param in required_params: + if not self.config.impl_cfg.get(param, {}).get("values"): + raise ValueError(f"Missing implementation parameter: {param}") + + def list_all(self): + """List all possible kernel configurations""" + w_p = Path(self.output_dir) + list_p = w_p / 'gemm_instance_blobs.txt' + self._list_config_groups() + with list_p.open('w') as list_f: + list_f.write(str(w_p / ("gemm_common.hpp")) + "\n") + list_f.write(str(w_p / ("gemm_instances.hpp")) + "\n") + list_f.write(str(w_p / ("gemm_dispatcher.hpp")) + "\n") + for group in self.all_kernels: + list_f.write(str(w_p / ("gemm_" + group + ".hpp")) + "\n") + + + + def _list_config_groups(self): + params = [ + ("pipeline", "pipeline"), + ("epilogue", "epilogue"), + ("scheduler", "scheduler"), + ("kPadM", "kPadM"), + ("kPadN", "kPadN"), + ("kPadK", "kPadK") + ] + + # Generate all unique_combinations + _unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params])) + for combo in _unique: + config = {name: value for (_, name), value in zip(params, combo)} + pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values() + # To 
remove some unsupported combinations + unsupported_combination = [("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave")] + if (pipeline, epilogue, scheduler) not in unsupported_combination: + group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" + self.all_kernels.append(group_name) + self.unique_configs.append(config) + + def generate_all(self): + self._generate_common_header() + self._generate_config_groups() + self._generate_dispatcher() + + + def _generate_common_header(self): + """Generate common header with datatypes and layout""" + ctype = self.config.datatype + atype = self.config.datatype + btype = self.config.datatype + if self.config.datatype in ['fp8', 'bf8']: + ctype = 'fp16' + elif self.config.datatype in ['int4']: + atype = 'fp16' + ctype = 'fp16' + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core.hpp" + +// Data types +using ADataType = {DATA_TYPE_MAP[atype]}; +using BDataType = {DATA_TYPE_MAP[btype]}; +using AccDataType = float; +using CDataType = {DATA_TYPE_MAP[ctype]}; + +// Layout configurations +using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; +using BLayout = {LAYOUT_MAP[self.config.layouts[1]]}; +using CLayout = {LAYOUT_MAP[self.config.layouts[2]]}; +""" + + + (self.output_dir / "gemm_common.hpp").write_text(content) + + def _generate_config_groups(self): + """Generate implementation configuration groups""" + if not self.unique_configs: # Check if the list is empty + self._list_config_groups() + for config in self.unique_configs: + self._generate_config_group(**config) + self.generate_common_instances_header() + + + def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str, + kPadM: bool, kPadN: bool, kPadK: bool): + """Generate a configuration group with all tile/warp combinations""" + group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" + filename = f"gemm_{group_name}.hpp" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_common.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/host.hpp" + +namespace {group_name} {{ +""" + # Add template struct with configuration + content += self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK) + + content += f"\n}} // namespace {group_name}\n" + (self.output_dir / filename).write_text(content) + + def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str, + kPadM: bool, kPadN: bool, kPadK: bool) -> str: + """Generate kernel struct template""" + return f""" +template +struct GemmKernel {{ + static constexpr bool kPadM = {BOOL_MAP(kPadM)}; + static constexpr bool kPadN = {BOOL_MAP(kPadN)}; + static constexpr bool kPadK = {BOOL_MAP(kPadK)}; + + static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ + static constexpr bool permuteA = false; + static constexpr bool permuteB = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; + + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; + {EPILOGUE_MAP[epilogue]} + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + if(s.log_level_ > 0) + {{ + std::cout << "Launching kernel with args:" + << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }} + + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + return ave_time; + + }}; + + if(has_hot_loop) {{ + {HOT_LOOP_TRUE[pipeline]} + }} else {{ + {HOT_LOOP_FALSE} + }} + + return ave_time; + }} +}}; +""" + + def generate_common_instances_header(self): + """Generate common instances header""" + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once +""" + for group in self.all_kernels: + content += f"#include \"gemm_{group}.hpp\"\n" + (self.output_dir / "gemm_instances.hpp").write_text(content) + + def _generate_dispatcher(self): + """Generate dispatch mechanism""" + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_common.hpp" +#include "gemm_instances.hpp" +#include "gemm_host_api.hpp" +#include +#include +#include + +struct GemmDispatcher { + static auto& get_kernel_map() { + // Use a static local variable + static std::unordered_map> kernel_map; + return kernel_map; + } + + static void init() { + auto& kernel_map = get_kernel_map(); + if(!kernel_map.empty()) return; + \n""" + # Add tile/warp instantiations + tile_params = set(itertools.product( + self.config.impl_cfg["tile_m"]["values"], + self.config.impl_cfg["tile_n"]["values"], + self.config.impl_cfg["tile_k"]["values"], + self.config.impl_cfg["warp_m"]["values"], + self.config.impl_cfg["warp_n"]["values"], + self.config.impl_cfg["warp_k"]["values"], + self.config.impl_cfg["warp_tile_m"]["values"], + self.config.impl_cfg["warp_tile_n"]["values"], + self.config.impl_cfg["warp_tile_k"]["values"] + )) + + + for group in self.all_kernels: + content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) {{ + std::vector results;""" + for tile in tile_params: + # Check if we have valid tile/warp combinations + # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m + if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ + ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): + continue + content += f""" + //we can have multiple tiles config for the one kernel_trait + return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);""" + content += """ + };\n""" + + content += """ } + + + static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + const ck_tile::stream_config& s) { + init(); + const std::string key = assemble_key(trait); + auto& kernel_map = get_kernel_map(); + if(auto it = kernel_map.find(key); it != kernel_map.end()) { + return it->second(gemm_args, s); //Running single instance + } + throw std::runtime_error("No suitable kernel found: " + key); + } + +private: + static std::string assemble_key(const KernelTraits &trait) { + return std::string(trait.pipeline) + "_" + + trait.epilogue + "_" + + trait.scheduler + "_" + + "pad_" + + (trait.kPadM ? "true" : "false") + "_" + + (trait.kPadN ? "true" : "false") + "_" + + (trait.kPadK ? "true" : "false"); + } +}; + +""" + (self.output_dir / "gemm_dispatcher.hpp").write_text(content) + + +def do_list_blobs(args, gemm_config): + generator = GemmCodeGenerator(args.working_path, gemm_config) + generator.list_all() + +def do_gen_blobs(args, gemm_config): + generator = GemmCodeGenerator(args.working_path, gemm_config) + generator.generate_all() + + + +def main(args): + # Read and validate json file + with open(args.json, 'r') as json_file: + config_data = json.load(json_file) + + # Validate and parse configuration + gemm_config = GemmConfig(config_data) + + if args.list_blobs: + do_list_blobs(args, gemm_config) + elif args.gen_blobs: + do_gen_blobs(args, gemm_config) + else: + # If neither was specified, either do nothing or default to gen_blobs + print("No mode specified (use --list_blobs or --gen_blobs). 
Generating by default...") + do_gen_blobs(args, gemm_config) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK gemm kernel", + ) + parser.add_argument( + "-w", "--working_path", default="./", required=False, help="the path where all the blobs are going to be generated" + ) + parser.add_argument( + "-j", "--json", required=True, help="Path to the json which contains the kernel configurations" + ) + parser.add_argument( + "-l", "--list_blobs", action = 'store_true', help="List all kernel to file" + ) + parser.add_argument( + "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels into different files" + ) + + args = parser.parse_args() + + main(args)