mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fixed GemmK size calculation.
This commit is contained in:
@@ -547,19 +547,18 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
input_right_pads,
|
||||
k_batch_initial);
|
||||
|
||||
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
|
||||
const auto& c_grid_desc_m_n = descs_initial[I2];
|
||||
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, M01, N01, k_batch_initial);
|
||||
|
||||
// Get the total K dimension size so that we don't make split-K value too small.
|
||||
const auto k_size = a_grid_desc_kbatch_k0_m_k1.GetLength(I0) * K1Number * K0PerBlock;
|
||||
const auto gemmK = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
|
||||
|
||||
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
|
||||
// Hence, the grid is just size of the tile map.
|
||||
const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, k_size);
|
||||
k_dim_size_ = gemmK;
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
|
||||
@@ -521,18 +521,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
|
||||
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
|
||||
const index_t GemmK = a_grid_desc_kbatch_k0_m_k1.GetLength(I0) * a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
|
||||
|
||||
// nullptr for output, will be set after workspace set
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, 1};
|
||||
|
||||
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
|
||||
// Hence, the grid is just size of the tile map.
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(gemm_arg.M, gemm_arg.N, 1, 1);
|
||||
const auto grid_size = gdx * gdy * gdz;
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, GemmK);
|
||||
const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(GemmM, GemmN);
|
||||
k_dim_size_ = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
|
||||
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -10,8 +10,10 @@ namespace device {
|
||||
struct ArgumentSplitK
|
||||
{
|
||||
index_t k_batch() const { return k_batch_; }
|
||||
index_t k_dim_size() const { return k_dim_size_; }
|
||||
protected:
|
||||
index_t k_batch_;
|
||||
index_t k_batch_{-1};
|
||||
index_t k_dim_size_{-1};
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
@@ -27,18 +28,15 @@ struct DeviceProperties
|
||||
int num_cu_;
|
||||
};
|
||||
|
||||
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t k_size)
|
||||
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size)
|
||||
{
|
||||
static DeviceProperties device_properties;
|
||||
const int num_cu = device_properties.num_cu_;
|
||||
auto k_batch = 1;
|
||||
//constexpr ck::index_t min_k_per_batch = 16;
|
||||
//const auto max_split_k = math::integer_divide_ceil(k_size, min_k_per_batch);
|
||||
|
||||
const auto optimal_split = static_cast<ck::index_t>(std::floor((max_occupancy * num_cu) / (grid_size)));
|
||||
if (optimal_split > 1)
|
||||
{
|
||||
//k_batch = std::min(optimal_split, max_split_k);
|
||||
k_batch = optimal_split;
|
||||
}
|
||||
|
||||
@@ -46,14 +44,28 @@ inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, c
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] K-dim size: " << k_size << std::endl;
|
||||
//std::cout << "[SPLIT-K AUTODEDUCE] Max split-k value: " << max_split_k << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl;
|
||||
}
|
||||
return k_batch;
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
inline index_t get_bwd_weight_gemm_k(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths)
|
||||
{
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
// The input array has elements in the order: G, N, K, Do, Ho, Wo
|
||||
// GemmK = N * Do * Ho * Wo for the BWD weight pass.
|
||||
constexpr index_t spatial_offset = 3;
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
|
||||
return gemmK;
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -113,6 +113,16 @@ struct PerfResults
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Records the K-dimension size for these results.
// Once a positive value has been stored, only the same value may be set again;
// a different value is reported on stderr and terminates the process.
void set_k_dim_size(ck::index_t k_dim_size)
{
    const bool has_value = k_dim_size_ > 0;
    const bool mismatch  = has_value && (k_dim_size != k_dim_size_);

    if(!mismatch)
    {
        // First assignment, or a repeat of the value already stored.
        k_dim_size_ = k_dim_size;
        return;
    }

    std::cerr << "Error: k_dim_size cannot be set multiple times. Old value " << k_dim_size_ << ". New value " << k_dim_size << std::endl;
    exit(EXIT_FAILURE);
}
|
||||
|
||||
// Global best results
|
||||
std::string best_op_name_;
|
||||
float best_avg_time_ = 0;
|
||||
@@ -135,6 +145,9 @@ struct PerfResults
|
||||
float opt_split_k_gb_per_sec_ = 0;
|
||||
ck::index_t opt_split_k_best_arg_ = 1;
|
||||
|
||||
// K-dim size
|
||||
ck::index_t k_dim_size_ = -1;
|
||||
|
||||
std::vector<std::tuple<std::string, ck::index_t, float>> ranking_;
|
||||
};
|
||||
|
||||
@@ -156,6 +169,7 @@ void write_perf_results_to_file(const PerfResults& perf_results_global,
|
||||
}
|
||||
file << res.opt_split_k_avg_time_ << separator
|
||||
<< res.opt_split_k_best_arg_ << separator
|
||||
<< res.k_dim_size_ << separator
|
||||
<< rank << separator
|
||||
<< total_num;
|
||||
};
|
||||
@@ -364,7 +378,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128};
|
||||
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128, 256};
|
||||
bool profile_all = true;
|
||||
if(split_k != "all")
|
||||
{
|
||||
@@ -421,6 +435,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
if (split_k_arg)
|
||||
{
|
||||
split_k_arg_value = split_k_arg->k_batch();
|
||||
const auto k_dim_size = split_k_arg->k_dim_size();
|
||||
if (k_dim_size > 0)
|
||||
{
|
||||
perf_results_local.set_k_dim_size(k_dim_size);
|
||||
perf_results_global.set_k_dim_size(k_dim_size);
|
||||
}
|
||||
supports_split_k_optimization = true;
|
||||
}
|
||||
|
||||
@@ -587,6 +607,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
<< std::get<0>(perf_results_global.get_ranking(perf_results_global.opt_split_k_best_op_name_, perf_results_global.opt_split_k_best_arg_))
|
||||
<< " / " << std::get<1>(perf_results_global.get_ranking(perf_results_global.opt_split_k_best_op_name_, perf_results_global.opt_split_k_best_arg_))
|
||||
<< std::endl;
|
||||
std::cerr << "K-dim size: " << perf_results_global.k_dim_size_ << std::endl;
|
||||
|
||||
write_perf_results_to_file(perf_results_global, perf_results_list);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user