mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
turbo config
This commit is contained in:
@@ -146,6 +146,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
int result1 = !run_grouped_gemm_example<GemmConfigComputeTurbo>(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
|
||||
@@ -101,6 +101,22 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeTurbo : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
{
|
||||
|
||||
@@ -259,15 +259,6 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
@@ -284,12 +275,6 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
}
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
@@ -311,13 +296,6 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
@@ -349,16 +327,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
|
||||
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
}
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
@@ -444,17 +413,6 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
@@ -505,6 +463,19 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
@@ -541,13 +512,6 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
@@ -569,13 +533,6 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
|
||||
Reference in New Issue
Block a user