WIP: Oversubscription factor.

This commit is contained in:
Ville Pietilä
2025-06-19 15:12:21 +00:00
parent e18e53ed16
commit 385defc8cd
5 changed files with 54 additions and 8 deletions

View File

@@ -549,13 +549,29 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto& c_grid_desc_m_n = descs_initial[I2];
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, M01, N01, k_batch_initial);
const auto gemmK = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
// Hence, the grid is just size of the tile map.
const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
k_dim_size_ = gemmK;
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
k_dim_size_ = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
const bool enable_oversubscription = k_dim_size_ > 1 << 13;
// For small GemmK size, cap the max value of the k_batch.
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, BlockSize, enable_oversubscription);
const auto k_batch_max = static_cast<index_t>((k_dim_size_ - 1) / K0PerBlock);
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_dim_size: "
<< k_dim_size_ << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] K0PerBlock: " << K0PerBlock << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: "
<< k_batch_max << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal k_batch value: "
<< k_batch_ << std::endl;
k_batch_ = std::min(k_batch_, k_batch_max);
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: "
<< k_batch_ << std::endl;
}
}
else
{

View File

@@ -526,7 +526,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Hence, the grid is just size of the tile map.
const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(GemmM, GemmN);
k_dim_size_ = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
const bool enable_oversubscription = k_dim_size_ > 1 << 13;
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, BlockSize, enable_oversubscription);
// Cap the k_batch_ value such that it doesn't violate the limit for the number of prefetch stages for the pipeline.
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
{
const auto k_batch_max = static_cast<index_t>(std::floor(
(k_dim_size_ - 1.0) / ((GridwiseGemm::BlockwiseGemmPipe::PrefetchStages-1.0) * K0PerBlock)));
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: "
<< k_batch_max << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal k_batch value: "
<< k_batch_ << std::endl;
k_batch_ = std::min(k_batch_, k_batch_max);
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: "
<< k_batch_ << std::endl;
}
}
}
else
{

View File

@@ -24,25 +24,35 @@ struct DeviceProperties
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu_ = dev_prop.multiProcessorCount;
max_num_active_wavefronts_per_cu_ = dev_prop.maxThreadsPerMultiProcessor / dev_prop.warpSize;
wavefront_size_ = dev_prop.warpSize;
};
int num_cu_;
int max_num_active_wavefronts_per_cu_;
int wavefront_size_;
};
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size)
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t blockSize, bool enable_oversubscription = true)
{
static DeviceProperties device_properties;
const int num_cu = device_properties.num_cu_;
auto k_batch = 1;
const auto optimal_split = static_cast<ck::index_t>(std::floor((max_occupancy * num_cu) / (grid_size)));
const ck::index_t oversubscription = enable_oversubscription
? static_cast<ck::index_t>(std::round((1.0 *device_properties.max_num_active_wavefronts_per_cu_ * device_properties.wavefront_size_) / blockSize))
: 1;
const auto optimal_split = static_cast<ck::index_t>(std::floor((1.0 *max_occupancy * num_cu) / (grid_size)));
if (optimal_split > 1)
{
k_batch = optimal_split;
k_batch = oversubscription * optimal_split;
}
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Block size: " << blockSize << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Oversubscription factor: " << oversubscription << " (oversubscription enabled = " << std::to_string(enable_oversubscription) << ")"<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl;

View File

@@ -556,7 +556,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
{
return false;
}
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
{

View File

@@ -378,7 +378,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128, 256};
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128};
bool profile_all = true;
if(split_k != "all")
{