mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
enable RDNA on elementwise add examples
- 64 to 32 wavefront size - add GfxId as a template parameter to Run
This commit is contained in:
@@ -26,7 +26,7 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, int GfxId>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType; // input data type
|
||||
@@ -63,6 +63,40 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
.data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
// --- Device Properties Query (for logging and warpSize) ---
|
||||
int deviceId;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
|
||||
hipDeviceProp_t props;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
|
||||
|
||||
std::cout << "Running on GPU: " << props.name << " (Architecture: " << props.gcnArchName << ")"
|
||||
<< std::endl;
|
||||
std::cout << "GfxId instantiated: " << GfxId << std::endl;
|
||||
|
||||
// These will hold the *values* for the ck_tile::sequence types
|
||||
// They are initialized based on the GfxId
|
||||
constexpr ck_tile::index_t selected_warp_tile = (GfxId == 1200) ? Gfx120x::WarpTile
|
||||
: (GfxId == 900) ? Gfx90x::WarpTile
|
||||
:
|
||||
/* else */ Generic::WarpTile;
|
||||
|
||||
// Use if constexpr to select the compile-time constants for the current GfxId
|
||||
bool fail = false;
|
||||
if constexpr(GfxId == 1200)
|
||||
{
|
||||
std::cout << "Using gfx120x-optimized parameters (template specialization)." << std::endl;
|
||||
}
|
||||
else if constexpr(GfxId == 900)
|
||||
{
|
||||
std::cout << "Using gfx90x-optimized parameters (template specialization)." << std::endl;
|
||||
}
|
||||
else
|
||||
{ // Fallback for GfxId == 0 or unknown
|
||||
std::cerr << "WARNING: No specific parameters for GfxId " << GfxId
|
||||
<< ". Using generic parameters." << std::endl;
|
||||
return fail;
|
||||
}
|
||||
|
||||
// Dividing the problem into blocktile, warptile, and vector
|
||||
// The blocktile is the size of the tile that will be processed by a single thread block (also
|
||||
// called work group) The warptile is the size of the tile that will be processed by a single
|
||||
@@ -75,9 +109,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// into blocks of this size)
|
||||
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp
|
||||
// will cover some part of blockTile)
|
||||
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
|
||||
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread
|
||||
// will cover some part of WarpTile)
|
||||
using WarpTile =
|
||||
ck_tile::sequence<selected_warp_tile>; // How many elements are covered by a warp
|
||||
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread
|
||||
// will cover some part of WarpTile)
|
||||
|
||||
// Interpretation of above configurations
|
||||
// Each thread will cover 1 element (Vector)
|
||||
@@ -159,8 +194,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
int deviceId;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
|
||||
hipDeviceProp_t props;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
|
||||
std::string arch_name = props.gcnArchName;
|
||||
|
||||
if(data_type == "fp16" && (arch_name.find("gfx12") != std::string::npos))
|
||||
return run<ck_tile::half_t, 1200>(arg_parser) ? 0 : -2;
|
||||
else if(data_type == "fp16" && (arch_name.find("gfx908") != std::string::npos))
|
||||
return run<ck_tile::half_t, 900>(arg_parser) ? 0 : -2;
|
||||
else
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
std::cerr << "Unsupported data type: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, int GfxId>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
@@ -46,13 +46,48 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
// --- Device Properties Query (for logging and warpSize) ---
|
||||
int deviceId;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
|
||||
hipDeviceProp_t props;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
|
||||
|
||||
std::cout << "Running on GPU: " << props.name << " (Architecture: " << props.gcnArchName << ")"
|
||||
<< std::endl;
|
||||
std::cout << "GfxId instantiated: " << GfxId << std::endl;
|
||||
|
||||
// These will hold the *values* for the ck_tile::sequence types
|
||||
// They are initialized based on the GfxId
|
||||
constexpr ck_tile::index_t selected_warp_tile = (GfxId == 1200) ? Gfx120x::WarpTile
|
||||
: (GfxId == 900) ? Gfx90x::WarpTile
|
||||
:
|
||||
/* else */ Generic::WarpTile;
|
||||
|
||||
// Use if constexpr to select the compile-time constants for the current GfxId
|
||||
bool fail = false;
|
||||
if constexpr(GfxId == 1200)
|
||||
{
|
||||
std::cout << "Using gfx120x-optimized parameters (template specialization)." << std::endl;
|
||||
}
|
||||
else if constexpr(GfxId == 900)
|
||||
{
|
||||
std::cout << "Using gfx90x-optimized parameters (template specialization)." << std::endl;
|
||||
}
|
||||
else
|
||||
{ // Fallback for GfxId == 0 or unknown
|
||||
std::cerr << "WARNING: No specific parameters for GfxId " << GfxId
|
||||
<< ". Using generic parameters." << std::endl;
|
||||
return fail;
|
||||
}
|
||||
|
||||
using BlockWarps =
|
||||
ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads
|
||||
// per warp, 512 threads in one block are NEEDED)
|
||||
using BlockTile =
|
||||
ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block)
|
||||
using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one
|
||||
// warp (64 threads))
|
||||
using WarpTile = ck_tile::sequence<1, 8 * selected_warp_tile>; // shape of one warpTile
|
||||
// (elements covered by one warp
|
||||
// (32/64 threads))
|
||||
using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread)
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize =
|
||||
@@ -107,8 +142,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
int deviceId;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&deviceId));
|
||||
hipDeviceProp_t props;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, deviceId));
|
||||
std::string arch_name = props.gcnArchName;
|
||||
|
||||
if(data_type == "fp16" && (arch_name.find("gfx12") != std::string::npos))
|
||||
return run<ck_tile::half_t, 1200>(arg_parser) ? 0 : -2;
|
||||
else if(data_type == "fp16" && (arch_name.find("gfx908") != std::string::npos))
|
||||
return run<ck_tile::half_t, 900>(arg_parser) ? 0 : -2;
|
||||
else
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
std::cerr << "Unsupported data type: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
20
include/ck_tile/core/arch/warp_tile_size.hpp
Normal file
20
include/ck_tile/core/arch/warp_tile_size.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// ============================================================================
|
||||
// Architecture-specific parameter definitions
|
||||
// We'll define all parameters for all supported architectures here.
|
||||
// ============================================================================
|
||||
|
||||
// Parameters for gfx120x (using a namespace for organization or just global constexpr)
|
||||
namespace Gfx120x {
|
||||
constexpr ck_tile::index_t WarpTile = 32;
|
||||
}
|
||||
|
||||
// Parameters for gfx90x (example values, adjust as needed)
|
||||
namespace Gfx90x {
|
||||
constexpr ck_tile::index_t WarpTile = 64;
|
||||
}
|
||||
|
||||
// Generic Parameters - should never be used in this example
|
||||
// templated run function should only be instantiated for Gfx120x and Gfx90x
|
||||
namespace Generic {
|
||||
constexpr ck_tile::index_t WarpTile = -1;
|
||||
}
|
||||
Reference in New Issue
Block a user