mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
implement device batched gemm b scale for wmma (#2825)
* rebased on top of develop
* fixed missing shuffeling and wrong indexing
* added tests for batched_b_scale
* added missing files
* fixed wrong stride computation and removed k batching (for now) due to precision issues
* reinstated k-batching with PRNG constrained to -1..1
* added specialization of GeneratorTensor_3 for int4 and fixed internal overflow
* added k-batching to reference and increased tolerances for test
* changed gemm_b_scale and gemm_universal tests to use correct parameters
* adressed review commentsd
* ported fixes back to non-batched version of b_scale
* adressed review comments
* run clang-format on older commits
* add type-conversion to AccDataType and then to CDataType to exactly mimic GPU's behavior
* added newline at end of file
* reflected changes from muitl-abd branch in batched b_scale
* fixed gfx11 issue
* changed range for pki4 to -1...1 (-0.5...0.5 never really made sense for i4 anyway and always should have caused compiler errors, but since there was no int4 specialization of GeneratorTensor3 until now, this passed
* run clang format
* set range of i4 generation to 0...1 for upstream tests to pass. This replicated previous behavior, which however means that it is NOT properly tested.
* reduced range for pk_i4 even further to 0..0
* removed failing xld instances. Failure now uncovered now that tests were fixed
* removed generation of int4 values entierly
* divide B buffer by BPackedSize
---------
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
[ROCm/composable_kernel commit: c4b2da9cbd]
This commit is contained in:
@@ -289,7 +289,6 @@ int main(int argc, char* argv[])
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
@@ -303,7 +302,6 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
|
||||
@@ -275,7 +275,7 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1, 1});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
@@ -289,7 +289,7 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1, 1});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
|
||||
Reference in New Issue
Block a user