enable RDNA on elementwise add examples

- use a warp tile of 32 instead of 64 on gfx12 (wavefront size 32 rather than 64)
- add GfxId as a template parameter to run()
This commit is contained in:
Philip Maybank
2025-07-21 09:56:02 -04:00
parent 42b2e3bc40
commit a70f4db370
3 changed files with 123 additions and 11 deletions

View File

@@ -26,7 +26,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
template <typename DataType, int GfxId>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType; // input data type
@@ -63,6 +63,40 @@ bool run(const ck_tile::ArgParser& arg_parser)
.data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy
x_buf_b.ToDevice(x_host_b.data());
// --- Device Properties Query (for logging and warpSize) ---
int deviceId;
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
std::cout << "Running on GPU: " << props.name << " (Architecture: " << props.gcnArchName << ")"
<< std::endl;
std::cout << "GfxId instantiated: " << GfxId << std::endl;
// These will hold the *values* for the ck_tile::sequence types
// They are initialized based on the GfxId
constexpr ck_tile::index_t selected_warp_tile = (GfxId == 1200) ? Gfx120x::WarpTile
: (GfxId == 900) ? Gfx90x::WarpTile
:
/* else */ Generic::WarpTile;
// Use if constexpr to select the compile-time constants for the current GfxId
bool fail = false;
if constexpr(GfxId == 1200)
{
std::cout << "Using gfx120x-optimized parameters (template specialization)." << std::endl;
}
else if constexpr(GfxId == 900)
{
std::cout << "Using gfx90x-optimized parameters (template specialization)." << std::endl;
}
else
{ // Fallback for GfxId == 0 or unknown
std::cerr << "WARNING: No specific parameters for GfxId " << GfxId
<< ". Using generic parameters." << std::endl;
return fail;
}
// Dividing the problem into blocktile, warptile, and vector
// The blocktile is the size of the tile that will be processed by a single thread block (also
// called work group) The warptile is the size of the tile that will be processed by a single
@@ -75,9 +109,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
// into blocks of this size)
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp
// will cover some part of blockTile)
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread
// will cover some part of WarpTile)
using WarpTile =
ck_tile::sequence<selected_warp_tile>; // How many elements are covered by a warp
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread
// will cover some part of WarpTile)
// Interpretation of above configurations
// Each thread will cover 1 element (Vector)
@@ -159,8 +194,19 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
int deviceId;
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
std::string arch_name = props.gcnArchName;
if(data_type == "fp16" && (arch_name.find("gfx12") != std::string::npos))
return run<ck_tile::half_t, 1200>(arg_parser) ? 0 : -2;
else if(data_type == "fp16" && (arch_name.find("gfx908") != std::string::npos))
return run<ck_tile::half_t, 900>(arg_parser) ? 0 : -2;
else
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
std::cerr << "Unsupported data type: " << data_type << std::endl;
return -1;
}
}

View File

@@ -17,7 +17,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
template <typename DataType, int GfxId>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType;
@@ -46,13 +46,48 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf_a.ToDevice(x_host_a.data());
x_buf_b.ToDevice(x_host_b.data());
// --- Device Properties Query (for logging and warpSize) ---
int deviceId;
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
std::cout << "Running on GPU: " << props.name << " (Architecture: " << props.gcnArchName << ")"
<< std::endl;
std::cout << "GfxId instantiated: " << GfxId << std::endl;
// These will hold the *values* for the ck_tile::sequence types
// They are initialized based on the GfxId
constexpr ck_tile::index_t selected_warp_tile = (GfxId == 1200) ? Gfx120x::WarpTile
: (GfxId == 900) ? Gfx90x::WarpTile
:
/* else */ Generic::WarpTile;
// Use if constexpr to select the compile-time constants for the current GfxId
bool fail = false;
if constexpr(GfxId == 1200)
{
std::cout << "Using gfx120x-optimized parameters (template specialization)." << std::endl;
}
else if constexpr(GfxId == 900)
{
std::cout << "Using gfx90x-optimized parameters (template specialization)." << std::endl;
}
else
{ // Fallback for GfxId == 0 or unknown
std::cerr << "WARNING: No specific parameters for GfxId " << GfxId
<< ". Using generic parameters." << std::endl;
return fail;
}
using BlockWarps =
ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads
// per warp, 512 threads in one block are NEEDED)
using BlockTile =
ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block)
using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one
// warp (64 threads))
using WarpTile = ck_tile::sequence<1, 8 * selected_warp_tile>; // shape of one warpTile
// (elements covered by one warp
// (32/64 threads))
using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread)
constexpr ck_tile::index_t kBlockSize =
@@ -107,8 +142,19 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
int deviceId;
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
std::string arch_name = props.gcnArchName;
if(data_type == "fp16" && (arch_name.find("gfx12") != std::string::npos))
return run<ck_tile::half_t, 1200>(arg_parser) ? 0 : -2;
else if(data_type == "fp16" && (arch_name.find("gfx908") != std::string::npos))
return run<ck_tile::half_t, 900>(arg_parser) ? 0 : -2;
else
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
std::cerr << "Unsupported data type: " << data_type << std::endl;
return -1;
}
}

View File

@@ -0,0 +1,20 @@
// ============================================================================
// Architecture-specific parameter definitions
// We'll define all parameters for all supported architectures here.
// ============================================================================
/// Compile-time tuning constants for the gfx120x family.
/// Grouped in a namespace so each architecture can expose the same constant names.
namespace Gfx120x {
/// Base warp-tile extent that the examples plug into their ck_tile::sequence WarpTile.
constexpr ck_tile::index_t WarpTile{32};
} // namespace Gfx120x
/// Compile-time tuning constants for the gfx90x family.
/// The values are examples and may need adjusting for a particular device.
namespace Gfx90x {
/// Base warp-tile extent that the examples plug into their ck_tile::sequence WarpTile.
constexpr ck_tile::index_t WarpTile{64};
} // namespace Gfx90x
// Generic parameters - placeholder for any architecture without a dedicated
// parameter set. These should never be used in this example: the templated
// run function should only be instantiated for Gfx120x and Gfx90x.
namespace Generic {
// Sentinel value, not a usable tile size.
// NOTE(review): presumably -1 is meant to make an accidental generic
// instantiation obviously invalid - confirm ck_tile::index_t is signed.
constexpr ck_tile::index_t WarpTile = -1;
}