mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Added wmma support for gemm quantization: (#2841)
- profiler for gemm quantization for DL/XDL - tests for gemm quantization for DL/XDL - implementation for gemm quantization for WMMA - profiler/tests for gemm qunatization for WMMA Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
2723dbd332
commit
f97b2a3f5d
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@@ -171,8 +172,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
|
||||
// be odd.
|
||||
constexpr bool AtomicsImplementationExists =
|
||||
!(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>) ||
|
||||
!(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
|
||||
std::is_same_v<EDataType, int8_t>) ||
|
||||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
|
||||
@@ -1065,6 +1065,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<EDataType>, int8_t>::value)
|
||||
{
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch
|
||||
<< " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user