resolved floating point error for some instances

This commit is contained in:
Astha Rai
2025-06-23 18:42:55 +00:00
parent 694c2eaadb
commit 62e4a80be5

View File

@@ -159,25 +159,15 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
const auto Run = [&](const auto& kernel) {
dim3 grid_dim;
if(arg.Grid_size < 0)
{
// printf("grid size is less than 0");
int occupancy, num_cu;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
arg.Grid_size = num_cu * occupancy;
grid_dim = arg.Grid_size;
}
else
{
// printf("grid size is not 0");
grid_dim = arg.Grid_size;
}
// printf("grid size is less than 0");
int occupancy /**, num_cu**/;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
// num_cu = dev_prop.multiProcessorCount;
grid_dim = arg.block_2_ctile_map_streamk.get_grid_dims();
if(stream_config.flush_cache)
@@ -749,6 +739,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
const auto calculate_grid_size = [&](const auto& kernel) {
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
// printf("init occupancy: %d/n", occupancy);
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
@@ -758,10 +749,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
if(has_main_k_block_loop)
{
// printf("into the main k block loop\n");
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
// printf("Case 1\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -775,6 +768,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
// printf("Case 2\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
@@ -784,6 +778,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
// printf("Case 3\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
@@ -796,6 +791,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
// printf("Case 4\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -810,6 +806,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
// printf("Case 5\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -824,6 +821,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
// printf("Case 6\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -838,6 +836,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
// printf("Case 7\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -852,6 +851,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
// printf("Case 8\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -866,6 +866,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
// printf("Case 9\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
@@ -882,6 +883,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
// printf("Case 10\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
@@ -892,6 +894,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
else
{
// printf("Case 11\n");
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
@@ -906,6 +909,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
// printf("Case 12\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
@@ -915,6 +919,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
else
{
// printf("Case 13\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
@@ -926,17 +931,34 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
else
{
// printf("not main k block loop\n");
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
{
// printf("Case 14\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
calculate_grid_size(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// printf("Case 15\n");
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
calculate_grid_size(kernel);
}
}
// printf("num_cu: %u\n", static_cast<uint32_t>(num_cu));
// printf("occupancy: %u\n", static_cast<uint32_t>(occupancy));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),