diff --git a/example/ck_tile/99_toy_example/00_add_vector/add_vector.cpp b/example/ck_tile/99_toy_example/00_add_vector/add_vector.cpp index ef5c2afae4..4431aadf71 100644 --- a/example/ck_tile/99_toy_example/00_add_vector/add_vector.cpp +++ b/example/ck_tile/99_toy_example/00_add_vector/add_vector.cpp @@ -26,7 +26,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { using XDataType = DataType; // input data type @@ -63,6 +63,40 @@ bool run(const ck_tile::ArgParser& arg_parser) .data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy x_buf_b.ToDevice(x_host_b.data()); + // --- Device Properties Query (for logging and warpSize) --- + int deviceId; + HIP_CHECK_ERROR(hipGetDevice(&deviceId)); + hipDeviceProp_t props; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId)); + + std::cout << "Running on GPU: " << props.name << " (Architecture: " << props.gcnArchName << ")" + << std::endl; + std::cout << "GfxId instantiated: " << GfxId << std::endl; + + // These will hold the *values* for the ck_tile::sequence types + // They are initialized based on the GfxId + constexpr ck_tile::index_t selected_warp_tile = (GfxId == 1200) ? Gfx120x::WarpTile + : (GfxId == 900) ? Gfx90x::WarpTile + : + /* else */ Generic::WarpTile; + + // Use if constexpr to select the compile-time constants for the current GfxId + bool fail = false; + if constexpr(GfxId == 1200) + { + std::cout << "Using gfx120x-optimized parameters (template specialization)." << std::endl; + } + else if constexpr(GfxId == 900) + { + std::cout << "Using gfx90x-optimized parameters (template specialization)." << std::endl; + } + else + { // Fallback for GfxId == 0 or unknown + std::cerr << "WARNING: No specific parameters for GfxId " << GfxId + << ". Using generic parameters." 
<< std::endl; + return fail; + } + // Dividing the problem into blocktile, warptile, and vector // The blocktile is the size of the tile that will be processed by a single thread block (also // called work group) The warptile is the size of the tile that will be processed by a single @@ -75,9 +109,10 @@ bool run(const ck_tile::ArgParser& arg_parser) // into blocks of this size) using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp // will cover some part of blockTile) - using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp - using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread - // will cover some part of WarpTile) + using WarpTile = + ck_tile::sequence; // How many elements are covered by a warp + using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread + // will cover some part of WarpTile) // Interpretation of above configurations // Each thread will cover 1 element (Vector) @@ -159,8 +194,19 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + int deviceId; + HIP_CHECK_ERROR(hipGetDevice(&deviceId)); + hipDeviceProp_t props; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId)); + std::string arch_name = props.gcnArchName; + + if(data_type == "fp16" && (arch_name.find("gfx12") != std::string::npos)) + return run(arg_parser) ? 0 : -2; + else if(data_type == "fp16" && (arch_name.find("gfx908") != std::string::npos)) + return run(arg_parser) ? 0 : -2; + else { - return run(arg_parser) ? 
0 : -2; + std::cerr << "Unsupported data type: " << data_type << std::endl; + return -1; } } diff --git a/example/ck_tile/99_toy_example/01_add/add.cpp b/example/ck_tile/99_toy_example/01_add/add.cpp index 3955fb2883..70c0de6efa 100644 --- a/example/ck_tile/99_toy_example/01_add/add.cpp +++ b/example/ck_tile/99_toy_example/01_add/add.cpp @@ -17,7 +17,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { using XDataType = DataType; @@ -46,13 +46,48 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf_a.ToDevice(x_host_a.data()); x_buf_b.ToDevice(x_host_b.data()); + // --- Device Properties Query (for logging and warpSize) --- + int deviceId; + HIP_CHECK_ERROR(hipGetDevice(&deviceId)); + hipDeviceProp_t props; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId)); + + std::cout << "Running on GPU: " << props.name << " (Architecture: " << props.gcnArchName << ")" + << std::endl; + std::cout << "GfxId instantiated: " << GfxId << std::endl; + + // These will hold the *values* for the ck_tile::sequence types + // They are initialized based on the GfxId + constexpr ck_tile::index_t selected_warp_tile = (GfxId == 1200) ? Gfx120x::WarpTile + : (GfxId == 900) ? Gfx90x::WarpTile + : + /* else */ Generic::WarpTile; + + // Use if constexpr to select the compile-time constants for the current GfxId + bool fail = false; + if constexpr(GfxId == 1200) + { + std::cout << "Using gfx120x-optimized parameters (template specialization)." << std::endl; + } + else if constexpr(GfxId == 900) + { + std::cout << "Using gfx90x-optimized parameters (template specialization)." << std::endl; + } + else + { // Fallback for GfxId == 0 or unknown + std::cerr << "WARNING: No specific parameters for GfxId " << GfxId + << ". Using generic parameters." 
<< std::endl; + return fail; + } + using BlockWarps = ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads // per warp, 512 threads in one block are NEEDED) using BlockTile = ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block) - using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one - // warp (64 threads)) + using WarpTile = ck_tile::sequence<1, 8 * selected_warp_tile>; // shape of one warpTile + // (elements covered by one warp + // (32/64 threads)) using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread) constexpr ck_tile::index_t kBlockSize = @@ -107,8 +142,19 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + int deviceId; + HIP_CHECK_ERROR(hipGetDevice(&deviceId)); + hipDeviceProp_t props; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId)); + std::string arch_name = props.gcnArchName; + + if(data_type == "fp16" && (arch_name.find("gfx12") != std::string::npos)) + return run(arg_parser) ? 0 : -2; + else if(data_type == "fp16" && (arch_name.find("gfx908") != std::string::npos)) + return run(arg_parser) ? 0 : -2; + else { - return run(arg_parser) ? 0 : -2; + std::cerr << "Unsupported data type: " << data_type << std::endl; + return -1; } } diff --git a/include/ck_tile/core/arch/warp_tile_size.hpp b/include/ck_tile/core/arch/warp_tile_size.hpp new file mode 100644 index 0000000000..c3751da219 --- /dev/null +++ b/include/ck_tile/core/arch/warp_tile_size.hpp @@ -0,0 +1,20 @@ +// ============================================================================ +// Architecture-specific parameter definitions +// We'll define all parameters for all supported architectures here. 
+// ============================================================================
+
+// gfx120x parameters (`inline constexpr`: one definition shared by every translation unit).
+namespace Gfx120x {
+inline constexpr ck_tile::index_t WarpTile = 32;
+}
+
+// gfx90x parameters (example values, adjust as needed).
+namespace Gfx90x {
+inline constexpr ck_tile::index_t WarpTile = 64;
+}
+
+// Generic parameters: -1 is an invalid placeholder that should never be used in this example;
+// the templated run function should only be instantiated for Gfx120x and Gfx90x.
+namespace Generic {
+inline constexpr ck_tile::index_t WarpTile = -1;
+}