Fixed GemmK size calculation.

This commit is contained in:
Ville Pietilä
2025-06-19 05:38:51 +00:00
parent 9d2f58de3a
commit e18e53ed16
5 changed files with 51 additions and 23 deletions

View File

@@ -547,19 +547,18 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
input_right_pads,
k_batch_initial);
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
const auto& c_grid_desc_m_n = descs_initial[I2];
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, M01, N01, k_batch_initial);
// Get the total K dimension size so that we don't make split-K value too small.
const auto k_size = a_grid_desc_kbatch_k0_m_k1.GetLength(I0) * K1Number * K0PerBlock;
const auto gemmK = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
// Hence, the grid is just size of the tile map.
const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, k_size);
k_dim_size_ = gemmK;
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
}
else {
else
{
k_batch_ = split_k;
}

View File

@@ -521,18 +521,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const index_t GemmK = a_grid_desc_kbatch_k0_m_k1.GetLength(I0) * a_grid_desc_kbatch_k0_m_k1.GetLength(I2);
// nullptr for output, will be set after workspace set
typename GridwiseGemm::Argument gemm_arg{
nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, 1};
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
// Hence, the grid is just size of the tile map.
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(gemm_arg.M, gemm_arg.N, 1, 1);
const auto grid_size = gdx * gdy * gdz;
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, GemmK);
const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(GemmM, GemmN);
k_dim_size_ = get_bwd_weight_gemm_k<NDimSpatial>(a_g_n_k_wos_lengths);
k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size);
}
else
{

View File

@@ -10,8 +10,10 @@ namespace device {
struct ArgumentSplitK
{
index_t k_batch() const { return k_batch_; }
index_t k_dim_size() const { return k_dim_size_; }
protected:
index_t k_batch_;
index_t k_batch_{-1};
index_t k_dim_size_{-1};
};
} // namespace device

View File

@@ -4,6 +4,7 @@
#pragma once
#include <hip/hip_runtime.h>
#include "ck/utility/env.hpp"
#include "ck/utility/number.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/ck.hpp"
@@ -27,18 +28,15 @@ struct DeviceProperties
int num_cu_;
};
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t k_size)
inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size)
{
static DeviceProperties device_properties;
const int num_cu = device_properties.num_cu_;
auto k_batch = 1;
//constexpr ck::index_t min_k_per_batch = 16;
//const auto max_split_k = math::integer_divide_ceil(k_size, min_k_per_batch);
const auto optimal_split = static_cast<ck::index_t>(std::floor((max_occupancy * num_cu) / (grid_size)));
if (optimal_split > 1)
{
//k_batch = std::min(optimal_split, max_split_k);
k_batch = optimal_split;
}
@@ -46,14 +44,28 @@ inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, c
{
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] K-dim size: " << k_size << std::endl;
//std::cout << "[SPLIT-K AUTODEDUCE] Max split-k value: " << max_split_k << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl;
}
return k_batch;
}
template <ck::index_t NDimSpatial>
inline index_t get_bwd_weight_gemm_k(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths)
{
static constexpr auto I1 = Number<1>{};
// The input array has elements in the order: G, N, K, Do, Ho, Wo
// GemmK = N * Do * Ho * Wo for the BWD weight pass.
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
return gemmK;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -113,6 +113,16 @@ struct PerfResults
return ss.str();
}
void set_k_dim_size(ck::index_t k_dim_size)
{
if (k_dim_size_ > 0 && k_dim_size != k_dim_size_)
{
std::cerr << "Error: k_dim_size cannot be set multiple times. Old value " << k_dim_size_ << ". New value " << k_dim_size << std::endl;
exit(EXIT_FAILURE);
}
k_dim_size_ = k_dim_size;
}
// Global best results
std::string best_op_name_;
float best_avg_time_ = 0;
@@ -135,6 +145,9 @@ struct PerfResults
float opt_split_k_gb_per_sec_ = 0;
ck::index_t opt_split_k_best_arg_ = 1;
// K-dim size
ck::index_t k_dim_size_ = -1;
std::vector<std::tuple<std::string, ck::index_t, float>> ranking_;
};
@@ -156,6 +169,7 @@ void write_perf_results_to_file(const PerfResults& perf_results_global,
}
file << res.opt_split_k_avg_time_ << separator
<< res.opt_split_k_best_arg_ << separator
<< res.k_dim_size_ << separator
<< rank << separator
<< total_num;
};
@@ -364,7 +378,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128};
std::vector<ck::index_t> split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128, 256};
bool profile_all = true;
if(split_k != "all")
{
@@ -421,6 +435,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
if (split_k_arg)
{
split_k_arg_value = split_k_arg->k_batch();
const auto k_dim_size = split_k_arg->k_dim_size();
if (k_dim_size > 0)
{
perf_results_local.set_k_dim_size(k_dim_size);
perf_results_global.set_k_dim_size(k_dim_size);
}
supports_split_k_optimization = true;
}
@@ -587,6 +607,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
<< std::get<0>(perf_results_global.get_ranking(perf_results_global.opt_split_k_best_op_name_, perf_results_global.opt_split_k_best_arg_))
<< " / " << std::get<1>(perf_results_global.get_ranking(perf_results_global.opt_split_k_best_op_name_, perf_results_global.opt_split_k_best_arg_))
<< std::endl;
std::cerr << "K-dim size: " << perf_results_global.k_dim_size_ << std::endl;
write_perf_results_to_file(perf_results_global, perf_results_list);
}