mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
WIP: Oversubscription factor.
This commit is contained in:
@@ -549,13 +549,29 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
const auto& c_grid_desc_m_n = descs_initial[I2];
|
||||
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, M01, N01, k_batch_initial);
|
||||
const auto gemmK = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
|
||||
|
||||
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
|
||||
// Hence, the grid is just size of the tile map.
|
||||
const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
|
||||
k_dim_size_ = gemmK;
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
|
||||
k_dim_size_ = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
|
||||
const bool enable_oversubscription = k_dim_size_ > 1 << 13;
|
||||
|
||||
// For small GemmK size, cap the max value of the k_batch.
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, BlockSize, enable_oversubscription);
|
||||
const auto k_batch_max = static_cast<index_t>((k_dim_size_ - 1) / K0PerBlock);
|
||||
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_dim_size: "
|
||||
<< k_dim_size_ << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] K0PerBlock: " << K0PerBlock << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: "
|
||||
<< k_batch_max << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal k_batch value: "
|
||||
<< k_batch_ << std::endl;
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: "
|
||||
<< k_batch_ << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -526,7 +526,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Hence, the grid is just size of the tile map.
|
||||
const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(GemmM, GemmN);
|
||||
k_dim_size_ = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
|
||||
const bool enable_oversubscription = k_dim_size_ > 1 << 13;
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, BlockSize, enable_oversubscription);
|
||||
|
||||
// Cap the k_batch_ value such that it doesn't violate the limit for the number of prefetch stages for the pipeline.
|
||||
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto k_batch_max = static_cast<index_t>(std::floor(
|
||||
(k_dim_size_ - 1.0) / ((GridwiseGemm::BlockwiseGemmPipe::PrefetchStages-1.0) * K0PerBlock)));
|
||||
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: "
|
||||
<< k_batch_max << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal k_batch value: "
|
||||
<< k_batch_ << std::endl;
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: "
|
||||
<< k_batch_ << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -24,25 +24,35 @@ struct DeviceProperties
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
|
||||
num_cu_ = dev_prop.multiProcessorCount;
|
||||
max_num_active_wavefronts_per_cu_ = dev_prop.maxThreadsPerMultiProcessor / dev_prop.warpSize;
|
||||
wavefront_size_ = dev_prop.warpSize;
|
||||
};
|
||||
int num_cu_;
|
||||
int max_num_active_wavefronts_per_cu_;
|
||||
int wavefront_size_;
|
||||
};
|
||||
|
||||
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size)
|
||||
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t blockSize, bool enable_oversubscription = true)
|
||||
{
|
||||
static DeviceProperties device_properties;
|
||||
const int num_cu = device_properties.num_cu_;
|
||||
auto k_batch = 1;
|
||||
|
||||
const auto optimal_split = static_cast<ck::index_t>(std::floor((max_occupancy * num_cu) / (grid_size)));
|
||||
const ck::index_t oversubscription = enable_oversubscription
|
||||
? static_cast<ck::index_t>(std::round((1.0 *device_properties.max_num_active_wavefronts_per_cu_ * device_properties.wavefront_size_) / blockSize))
|
||||
: 1;
|
||||
|
||||
const auto optimal_split = static_cast<ck::index_t>(std::floor((1.0 *max_occupancy * num_cu) / (grid_size)));
|
||||
if (optimal_split > 1)
|
||||
{
|
||||
k_batch = optimal_split;
|
||||
k_batch = oversubscription * optimal_split;
|
||||
}
|
||||
|
||||
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Block size: " << blockSize << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Oversubscription factor: " << oversubscription << " (oversubscription enabled = " << std::to_string(enable_oversubscription) << ")"<< std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl;
|
||||
|
||||
@@ -556,7 +556,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
|
||||
{
|
||||
|
||||
@@ -378,7 +378,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128, 256};
|
||||
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128};
|
||||
bool profile_all = true;
|
||||
if(split_k != "all")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user