diff -Nru gemmlowp-0.0~git20190128.58825b1/BUILD gemmlowp-0.0~git20190708.a227af1/BUILD --- gemmlowp-0.0~git20190128.58825b1/BUILD 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/BUILD 2019-07-08 13:42:10.000000000 +0000 @@ -94,6 +94,7 @@ "fixedpoint/*.h", ]) + [ "internal/common.h", + "internal/detect_platform.h", ], visibility = ["//visibility:private"], ) diff -Nru gemmlowp-0.0~git20190128.58825b1/CONTRIBUTORS gemmlowp-0.0~git20190708.a227af1/CONTRIBUTORS --- gemmlowp-0.0~git20190128.58825b1/CONTRIBUTORS 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/CONTRIBUTORS 2019-07-08 13:42:10.000000000 +0000 @@ -18,6 +18,7 @@ Justine Tunney Mark J. Matthews Marie White +Suharsh Sivakumar Intel: Sagi Marcovich diff -Nru gemmlowp-0.0~git20190128.58825b1/debian/changelog gemmlowp-0.0~git20190708.a227af1/debian/changelog --- gemmlowp-0.0~git20190128.58825b1/debian/changelog 2019-02-24 02:24:17.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/debian/changelog 2019-08-13 01:57:58.000000000 +0000 @@ -1,3 +1,9 @@ +gemmlowp (0.0~git20190708.a227af1-1) unstable; urgency=medium + + * New upstream version 0.0~git20190708.a227af1 + + -- Mo Zhou Tue, 13 Aug 2019 01:57:58 +0000 + gemmlowp (0.0~git20190128.58825b1-1) unstable; urgency=medium * New upstream version 0.0~git20190128.58825b1 diff -Nru gemmlowp-0.0~git20190128.58825b1/doc/public.md gemmlowp-0.0~git20190708.a227af1/doc/public.md --- gemmlowp-0.0~git20190128.58825b1/doc/public.md 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/doc/public.md 2019-07-08 13:42:10.000000000 +0000 @@ -56,7 +56,7 @@ * `InputScalar`: The scalar type of the LHS and RHS operands. At the moment, this must be `std::uint8_t`. -* `OutputScalar`: The scalar type of the LHS and RHS operands. At the moment, +* `OutputScalar`: The scalar type of the result. At the moment, this must be `std::uint8_t`. * `BitDepthParams`: Defines the bit format of the input and output matrices and the required accuracy of the computation. At the moment, the only diff -Nru gemmlowp-0.0~git20190128.58825b1/fixedpoint/fixedpoint_avx.h gemmlowp-0.0~git20190708.a227af1/fixedpoint/fixedpoint_avx.h --- gemmlowp-0.0~git20190128.58825b1/fixedpoint/fixedpoint_avx.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/fixedpoint/fixedpoint_avx.h 2019-07-08 13:42:10.000000000 +0000 @@ -19,6 +19,7 @@ #include #include "fixedpoint.h" +#include "fixedpoint_sse.h" namespace gemmlowp { @@ -214,4 +215,4 @@ } // end namespace gemmlowp -#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ \ No newline at end of file +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ diff -Nru gemmlowp-0.0~git20190128.58825b1/fixedpoint/fixedpoint.h gemmlowp-0.0~git20190708.a227af1/fixedpoint/fixedpoint.h --- gemmlowp-0.0~git20190128.58825b1/fixedpoint/fixedpoint.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/fixedpoint/fixedpoint.h 2019-07-08 13:42:10.000000000 +0000 @@ -121,8 +121,8 @@ // in the overflow case, we just want to avoid undefined behavior. // // tIntegerType may be int32 or any narrower signed type. -template -tIntegerType ShiftLeft(tIntegerType a, int offset) { +template +tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) { const std::int64_t wide_a = static_cast(a); const std::int64_t wide_shifted = wide_a * (1 << offset); const auto min = std::numeric_limits::min(); @@ -353,8 +353,8 @@ // Correctly-rounded-to-nearest division by a power-of-two. // Also known as a rounding arithmetic right shift. 
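For context on the hunk that follows: the change only generalizes the type of the exponent argument, so that in the SIMD specializations added later in this patch the exponent can itself be a vector carrying a per-lane value. For a scalar, the operation behaves like the simplified model below; the helper name is invented for illustration, and the real implementation is written with masks and shifts so that the same code also works on register types.

#include <cassert>
#include <cstdint>

// Divide by 2^exponent, rounding to nearest, with halves rounded away from zero.
std::int32_t RoundingDivideByPOTModel(std::int32_t x, int exponent) {
  assert(exponent >= 0 && exponent <= 31);
  const std::int64_t half = (std::int64_t{1} << exponent) >> 1;  // 0 when exponent == 0
  const std::int64_t adjusted = (x >= 0) ? x + half : x - half;
  return static_cast<std::int32_t>(adjusted / (std::int64_t{1} << exponent));
}
// Examples: RoundingDivideByPOTModel(5, 1) == 3, RoundingDivideByPOTModel(-5, 1) == -3.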
-template -inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { +template +inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { assert(exponent >= 0); assert(exponent <= 31); const IntegerType mask = Dup((1ll << exponent) - 1); diff -Nru gemmlowp-0.0~git20190128.58825b1/fixedpoint/fixedpoint_neon.h gemmlowp-0.0~git20190708.a227af1/fixedpoint/fixedpoint_neon.h --- gemmlowp-0.0~git20190128.58825b1/fixedpoint/fixedpoint_neon.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/fixedpoint/fixedpoint_neon.h 2019-07-08 13:42:10.000000000 +0000 @@ -115,6 +115,16 @@ } template <> +inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) { + return vshlq_s32(a, offset); +} + +template <> +inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) { + return vshlq_s16(a, offset); +} + +template <> inline int32x4_t ShiftRight(int32x4_t a, int offset) { return vshlq_s32(a, vdupq_n_s32(-offset)); } @@ -280,6 +290,22 @@ const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15); const int16x8_t fixed_up_x = vqaddq_s16(x, fixup); return vrshlq_s16(fixed_up_x, shift_vec); +} + +template <> +inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) { + const int32x4_t shift_vec = vnegq_s32(exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +template <> +inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) { + const int16x8_t shift_vec = vnegq_s16(exponent); + const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15); + const int16x8_t fixed_up_x = vqaddq_s16(x, fixup); + return vrshlq_s16(fixed_up_x, shift_vec); } template diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/dispatch_gemm_shape.h gemmlowp-0.0~git20190708.a227af1/internal/dispatch_gemm_shape.h --- gemmlowp-0.0~git20190128.58825b1/internal/dispatch_gemm_shape.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/dispatch_gemm_shape.h 2019-07-08 13:42:10.000000000 +0000 @@ -85,6 +85,22 @@ } }; +template +struct TransposeImpl> { + typedef OutputStageScaleInt32ByFixedPointAndExponentPC SrcType; + static const VectorShape TransposedShape = TransposeVectorShape::Value; + typedef OutputStageScaleInt32ByFixedPointAndExponentPC + DstType; + static DstType Run(const SrcType& src) { + DstType dst; + dst.result_fixedpoint_multiplier = + Transpose(src.result_fixedpoint_multiplier); + dst.result_exponent = Transpose(src.result_exponent); + dst.result_offset_after_shift = src.result_offset_after_shift; + return dst; + } +}; + template struct TransposeImpl> { typedef OutputStageBiasAddition SrcType; diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/kernel_default.h gemmlowp-0.0~git20190708.a227af1/internal/kernel_default.h --- gemmlowp-0.0~git20190128.58825b1/internal/kernel_default.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/kernel_default.h 2019-07-08 13:42:10.000000000 +0000 @@ -20,74 +20,84 @@ #include "../public/bit_depth.h" #include "common.h" +#include "kernel.h" #include "kernel_reference.h" namespace gemmlowp { -template +template struct DefaultKernelImpl {}; // Partial specialization implementing the logic that if we want to use -// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall -// back to a generic kernel not taking advantage of LhsAlwaysNonzero. 
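The fallback logic referred to in these comments is plain partial-specialization dispatch: if no kernel is registered for a requested combination of properties, the trait inherits from the entry that ignores the unsupported property. A stripped-down analogue, with invented flag and kernel names purely for illustration:

#include <iostream>

// Primary template: the generic kernel that makes no extra assumptions.
template <bool FastPath, bool LhsNonZero>
struct KernelChoice {
  static const char* Name() { return "generic kernel"; }
};

// Requesting an LhsNonZero kernel that does not exist decays to the kernel
// that ignores that property.
template <bool FastPath>
struct KernelChoice<FastPath, true> : KernelChoice<FastPath, false> {};

// A kernel that does exist overrides the fallback.
template <>
struct KernelChoice<true, false> {
  static const char* Name() { return "fast-path kernel"; }
};

int main() {
  std::cout << KernelChoice<true, true>::Name() << "\n";   // fast-path kernel, via fallback
  std::cout << KernelChoice<false, true>::Name() << "\n";  // generic kernel, via fallback
}

The DefaultKernelImpl specializations in this file do the same thing, only with three boolean properties (MaxProductIsLessThan4096, IsUnsigned, LhsAlwaysNonZero) instead of two.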
-template -struct DefaultKernelImpl - : DefaultKernelImpl {}; - -// Partial specialization implementing the logic that if we want to use // a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we // fall back to a generic kernel not taking advantage of // MaxProductIsLessThan4096. +template +struct DefaultKernelImpl + : DefaultKernelImpl {}; + +// Partial specialization implementing the logic that if we want to use +// a kernel for LhsNonZero but do not have such a kernel, then we fall +// back to a generic kernel not taking advantage of LhsNonZero. template -struct DefaultKernelImpl - : DefaultKernelImpl {}; +struct DefaultKernelImpl + : DefaultKernelImpl {}; template struct DefaultKernel : DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue * BitDepthParams::RhsRange::kMaxValue < 4096), - (BitDepthParams::LhsRange::kMinValue > 0)> {}; + (BitDepthParams::LhsRange::kMinValue >= 0), + (BitDepthParams::LhsRange::kMinValue > 0 || + (BitDepthParams::LhsRange::kMaxValue <= 127 && + BitDepthParams::LhsRange::kMinValue > -128))> {}; } // end namespace gemmlowp -#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \ - LhsAlwaysNonzero, Kernel) \ - namespace gemmlowp { \ - template <> \ - struct DefaultKernelImpl \ - : Kernel {}; \ +#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, IsUnsigned, \ + LhsAlwaysNonZero, Kernel) \ + namespace gemmlowp { \ + template <> \ + struct DefaultKernelImpl : Kernel {}; \ } +// User-provided int8 inputs is only supported in the NEON path currently. #if defined GEMMLOWP_NEON_32 #include "kernel_neon.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2) -GEMMLOWP_SET_DEFAULT_KERNEL(true, false, +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_32_Kernel12x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(true, true, false, NEON_32_Kernel12x4Depth2Assuming12BitProducts) -GEMMLOWP_SET_DEFAULT_KERNEL(false, true, +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, NEON_32bit_GEMM_Int8Operands_LhsNonzero) +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true, + NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs) #elif defined GEMMLOWP_NEON_64 #include "kernel_neon.h" #if defined GEMMLOWP_DOTPROD_KERNEL -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth4_dotprod) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, + NEON_64_Kernel12x8Depth4_dotprod) #else -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2) -GEMMLOWP_SET_DEFAULT_KERNEL(false, true, +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_64_Kernel12x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, NEON_64bit_GEMM_Int8Operands_LhsNonzero) #endif +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true, + NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs) #elif defined(GEMMLOWP_MSA) #include "kernel_msa.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2) -GEMMLOWP_SET_DEFAULT_KERNEL(false, true, MSA_GEMM_Int8Operands_LhsNonzero) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, MSA_Kernel12x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, MSA_GEMM_Int8Operands_LhsNonzero) #elif defined GEMMLOWP_SSE4_32 #include "kernel_sse.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_32_Kernel4x4Depth2) #elif defined GEMMLOWP_SSE4_64 #include "kernel_sse.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_64_Kernel12x4Depth2) #elif defined GEMMLOWP_AVX2_64 #include "kernel_avx.h" 
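Referring back to the DefaultKernel definition above, here is a worked example of how the three dispatch booleans are derived; the operand range is hypothetical, chosen only to exercise the three predicates, and is not one defined by the library:

#include <iostream>

constexpr int kLhsMin = -127, kLhsMax = 127;
constexpr int kRhsMax = 127;

constexpr bool kMaxProductIsLessThan4096 = kLhsMax * kRhsMax < 4096;  // false
constexpr bool kIsUnsigned = kLhsMin >= 0;                            // false
constexpr bool kLhsAlwaysNonZero =
    kLhsMin > 0 || (kLhsMax <= 127 && kLhsMin > -128);                // true

int main() {
  std::cout << kMaxProductIsLessThan4096 << kIsUnsigned << kLhsAlwaysNonZero
            << "\n";  // prints 001
}

These three values line up with the first three arguments of the GEMMLOWP_SET_DEFAULT_KERNEL registrations in this block.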
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, AVX2_64_Kernel24x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, AVX2_64_Kernel24x8Depth2) #else #include "kernel_reference.h" namespace gemmlowp { @@ -96,7 +106,7 @@ KernelSideFormat, 1> > > DefaultReferenceKernel; } -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, DefaultReferenceKernel) #endif #endif // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_ diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/kernel.h gemmlowp-0.0~git20190708.a227af1/internal/kernel.h --- gemmlowp-0.0~git20190128.58825b1/internal/kernel.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/kernel.h 2019-07-08 13:42:10.000000000 +0000 @@ -145,12 +145,24 @@ static const int kCells = tCells; static const int kWidth = kCells * Cell::kWidth; static const int kDepth = Cell::kDepth; - typedef std::uint8_t Scalar; + typedef std::uint8_t Scalar; // The scalar type of the Format. + typedef std::uint8_t InputScalar; // The scalar type of the original input. }; +// KernelSideFormat for int8 fast kernel trick. The original input is uint8, but +// packs converts it to int8. template struct KernelSideFormatInt8 : KernelSideFormat { typedef std::int8_t Scalar; + typedef std::uint8_t InputScalar; +}; + +// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without +// pack conversion. +template +struct KernelSideFormatInt8Inputs : KernelSideFormat { + typedef std::int8_t Scalar; + typedef std::int8_t InputScalar; }; // KernelFormat describes fully the input data layout that a kernel expects. @@ -216,19 +228,24 @@ virtual ~KernelBase() {} }; -template +template struct ZeroPointInputValue {}; template <> -struct ZeroPointInputValue { +struct ZeroPointInputValue { static constexpr std::uint8_t kValue = 0; }; template <> -struct ZeroPointInputValue { +struct ZeroPointInputValue { static constexpr std::uint8_t kValue = 128; }; +template <> +struct ZeroPointInputValue { + static constexpr std::uint8_t kValue = 0; +}; + } // namespace gemmlowp #endif // GEMMLOWP_INTERNAL_KERNEL_H_ diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/kernel_neon.h gemmlowp-0.0~git20190708.a227af1/internal/kernel_neon.h --- gemmlowp-0.0~git20190128.58825b1/internal/kernel_neon.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/kernel_neon.h 2019-07-08 13:42:10.000000000 +0000 @@ -924,6 +924,17 @@ } }; +// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that +// requires that user inputs were originally int8. This avoids the uint8->int8 +// conversion in the pack step. +struct NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs + : NEON_32bit_GEMM_Int8Operands_LhsNonzero { + typedef KernelFormat< + KernelSideFormatInt8Inputs, 1>, + KernelSideFormatInt8Inputs, 1> > + Format; +}; + #endif // GEMMLOWP_NEON_32 // The kernels here are specifically arm 64bit assembly, not arm 32bit. @@ -1265,6 +1276,17 @@ } }; +// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that +// requires that user inputs were originally int8. This avoids the uint8->int8 +// conversion in the pack step. +struct NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs + : NEON_64bit_GEMM_Int8Operands_LhsNonzero { + typedef KernelFormat< + KernelSideFormatInt8Inputs, 1>, + KernelSideFormatInt8Inputs, 1> > + Format; +}; + // Our main GEMM kernel. 
struct NEON_64_Kernel12x8Depth2 : KernelBase { typedef KernelFormat, 3>, diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/multi_thread_gemm.h gemmlowp-0.0~git20190708.a227af1/internal/multi_thread_gemm.h --- gemmlowp-0.0~git20190128.58825b1/internal/multi_thread_gemm.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/multi_thread_gemm.h 2019-07-08 13:42:10.000000000 +0000 @@ -296,7 +296,6 @@ // Thread entry point. void ThreadFunc() { ScopedProfilingLabel label("Worker::ThreadFunc"); - RegisterCurrentThreadForProfiling(); ChangeState(State::Ready); @@ -373,18 +372,41 @@ } } - void Execute(const std::vector& tasks) { - assert(tasks.size() >= 1); + // Just executes the tasks. Does not destroy them. Similar to + // ruy::ThreadPool::Execute. + template + void Execute(int tasks_count, TaskType* tasks) { + assert(tasks_count >= 1); // One of the tasks will be run on the current thread. - std::size_t workers_count = tasks.size() - 1; + std::size_t workers_count = tasks_count - 1; CreateWorkers(workers_count); assert(workers_count <= workers_.size()); counter_to_decrement_when_ready_.Reset(workers_count); - int n = 0; - std::for_each(tasks.begin(), --tasks.end(), - [this, &n](Task* task) { workers_[n++]->StartWork(task); }); + for (std::size_t i = 0; i < tasks_count - 1; i++) { + workers_[i]->StartWork(&tasks[i]); + } // Execute the remaining workload immediately on the current thread. - Task* task = tasks.back(); + Task* task = &tasks[tasks_count - 1]; + task->local_allocator = &main_thread_task_allocator_; + task->Run(); + // Wait for the workers submitted above to finish. + counter_to_decrement_when_ready_.Wait(); + } + + // Legacy: executes the tasks and destroys them + void LegacyExecuteAndDestroyTasks(const std::vector& tasks) { + std::size_t tasks_count = tasks.size(); + assert(tasks_count >= 1); + // One of the tasks will be run on the current thread. + std::size_t workers_count = tasks_count - 1; + CreateWorkers(workers_count); + assert(workers_count <= workers_.size()); + counter_to_decrement_when_ready_.Reset(workers_count); + for (int i = 0; i < tasks_count - 1; i++) { + workers_[i]->StartWork(tasks[i]); + } + // Execute the remaining workload immediately on the current thread. + Task* task = tasks[tasks_count - 1]; task->local_allocator = &main_thread_task_allocator_; task->Run(); // Wait for the workers submitted above to finish. @@ -394,6 +416,11 @@ std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; }); } + // Legacy old name of LegacyExecuteAndDestroyTasks + void Execute(const std::vector& tasks) { + LegacyExecuteAndDestroyTasks(tasks); + } + private: // Ensures that the pool has at least the given count of workers. 
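To make the ownership difference in the new thread-pool interface concrete: the templated Execute above borrows caller-owned tasks, while the std::vector<Task*> overload keeps the old behaviour of deleting them. A usage sketch, assuming the surrounding class is gemmlowp's internal workers pool and that Task is the usual base class with a virtual Run(); any name not visible in this diff is an assumption:

#include <vector>
#include "multi_thread_gemm.h"  // the internal header patched here

struct MyTask : gemmlowp::Task {
  void Run() override { /* compute one slice of the GEMM */ }
};

void RunBothWays(gemmlowp::WorkersPool* pool) {
  // New interface: the pool runs the tasks but the caller keeps ownership.
  std::vector<MyTask> tasks(4);
  pool->Execute(static_cast<int>(tasks.size()), tasks.data());
  // tasks are still valid here and may be reused.

  // Legacy interface: heap-allocated tasks are deleted by the pool.
  std::vector<gemmlowp::Task*> owned_tasks;
  for (int i = 0; i < 4; ++i) owned_tasks.push_back(new MyTask);
  pool->Execute(owned_tasks);  // forwards to LegacyExecuteAndDestroyTasks
}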
// If any new worker has to be created, this function waits for it to diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/output.h gemmlowp-0.0~git20190708.a227af1/internal/output.h --- gemmlowp-0.0~git20190128.58825b1/internal/output.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/output.h 2019-07-08 13:42:10.000000000 +0000 @@ -22,6 +22,7 @@ #include #include #include +#include #include "../fixedpoint/fixedpoint.h" #include "../public/output_stages.h" @@ -179,7 +180,47 @@ int right_shift; }; -// Implementation of OutputStageSaturatingCastToUint8 for scalar data +template +struct OutputStageEvalImpl< + OutputStageScaleInt32ByFixedPointAndExponentPC, + RegisterBlock> { + typedef RegisterBlock InputType; + typedef RegisterBlock OutputType; + + typedef OutputStageScaleInt32ByFixedPointAndExponentPC OutputStage; + + OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input, int row, int col) const { + OutputType output; + const int pos = Shape == VectorShape::Row ? col : row; + using RegisterType = typename InputType::RegisterType; + const RegisterType result_offset_after_shift = + Dup(output_stage.result_offset_after_shift); + auto left_shift = + LoadForBroadcasting(output_stage.result_exponent, pos); + auto right_shift = + LoadForBroadcasting(output_stage.result_exponent, pos); + const auto result_fixedpoint_multiplier = LoadForBroadcasting( + output_stage.result_fixedpoint_multiplier, pos); + for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) { + left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0); + right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0); + } + const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul( + BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier); + const auto rdpot_val = + BroadcastRoundingDivideByPOT(mulhigh_val, right_shift); + for (int i = 0; i < InputType::kRegisterCount; i++) { + output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift); + } + return output; + } + + const OutputStage& output_stage; +}; + +// Implementation of OutputStageSaturatingCastToUint8 for scalar data. template struct OutputStageEvalBufferImpl> { @@ -202,7 +243,30 @@ } }; -// Implementation of OutputStageSaturatingCastToInt16 for scalar data +// Implementation of OutputStageSaturatingCastToInt8 for scalar data. +template +struct OutputStageEvalBufferImpl> { + typedef RegisterBuffer InputType; + typedef RegisterBuffer OutputType; + static_assert(InputType::kRegisterLanes == 1, + "This path is only for scalar values"); + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + for (int i = 0; i < InputType::kRegisterCount; i++) { + std::int32_t data = input.reg[i]; + output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data; + } + return output; + } +}; + +// Implementation of OutputStageSaturatingCastToInt16 for scalar data. 
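Putting the per-channel (PC) output stage above into scalar form may help: for each output value, the fixed-point multiplier and exponent are looked up by row or column ("channel"), a non-negative exponent becomes a saturating left shift and a negative one a rounding right shift, and the offset is added last. A single-value sketch using the scalar helpers from fixedpoint.h; the function name and signature are invented for illustration:

#include <algorithm>
#include <cstdint>
#include "fixedpoint/fixedpoint.h"  // path relative to the gemmlowp root

std::int32_t ScaleOneAccumulatorPerChannel(std::int32_t acc, int channel,
                                           const std::int32_t* multipliers,
                                           const std::int32_t* exponents,
                                           std::int32_t offset_after_shift) {
  const int exponent = static_cast<int>(exponents[channel]);
  const int left_shift = std::max(exponent, 0);
  const int right_shift = std::max(-exponent, 0);
  const std::int32_t scaled = gemmlowp::SaturatingRoundingDoublingHighMul(
      gemmlowp::ShiftLeft(acc, left_shift), multipliers[channel]);
  return gemmlowp::RoundingDivideByPOT(scaled, right_shift) +
         offset_after_shift;
}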
template struct OutputStageEvalBufferImpl> { diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/output_neon.h gemmlowp-0.0~git20190708.a227af1/internal/output_neon.h --- gemmlowp-0.0~git20190128.58825b1/internal/output_neon.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/output_neon.h 2019-07-08 13:42:10.000000000 +0000 @@ -108,6 +108,90 @@ }; template <> +struct OutputStageEvalBufferImpl> { + typedef RegBufferInt32<4> InputType; + typedef RegBufferInt8<4> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x4_t res_16 = vqmovn_s32(input.reg[0]); + int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16)); + output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl> { + typedef RegBufferInt32<8> InputType; + typedef RegBufferInt8<8> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16 = + vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1])); + output.reg[0] = vqmovn_s16(res_16); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl> { + typedef RegBufferInt32<16> InputType; + typedef RegBufferInt8<16> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16_0 = + vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1])); + int16x8_t res_16_1 = + vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3])); + output.reg[0] = vqmovn_s16(res_16_0); + output.reg[1] = vqmovn_s16(res_16_1); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl> { + typedef RegBufferInt32<32> InputType; + typedef RegBufferInt8<32> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16[4]; + for (int i = 0; i < 4; i++) { + res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]), + vqmovn_s32(input.reg[2 * i + 1])); + } + for (int i = 0; i < 4; i++) { + output.reg[i] = vqmovn_s16(res_16[i]); + } + return output; + } +}; + +template <> struct OutputStageEvalBufferImpl> { typedef RegBufferInt32<4> InputType; @@ -556,8 +640,8 @@ vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]); } } else { + int row_stride = dst->rows_stride(); for (int i = 0; i < 4; i++) { - int row_stride = dst->rows_stride(); std::uint8_t* col_ptr = dst_ptr + i; vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0); vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1); @@ -621,6 +705,153 @@ } } }; + +template +struct StoreFinalOutputImpl, DstType> { + static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row, + int col) { + const std::int32_t src_reg = src.buf.reg[0]; + for (int i = 0; i < 4; i++) { + *dst->data(row + i, col) = (src_reg >> (8 * i)); + } + } +}; + +template +struct StoreFinalOutputImpl, DstType> { + static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row, + int col) { + for (int i = 0; i < 4; i++) { + *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); + } + } +}; + +template +struct StoreFinalOutputImpl, DstType> { + static void 
Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row, + int col) { + std::int8_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + vst1_s8(dst_ptr, src.buf.reg[0]); + } else { + const int row_stride = dst->rows_stride(); + vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0); + vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1); + vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2); + vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3); + vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4); + vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5); + vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6); + vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7); + } + } +}; + +template +struct StoreFinalOutputImpl, DstType> { + static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row, + int col) { + std::int8_t* dst_ptr = dst->data(row, col); + const int row_stride = dst->rows_stride(); + const int col_stride = dst->cols_stride(); + for (int i = 0; i < 2; i++) { + vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 0); + vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 1); + vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 2); + vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 3); + vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 4); + vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 5); + vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 6); + vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 7); + } + } +}; + +template +struct StoreFinalOutputImpl, DstType> { + static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row, + int col) { + std::int8_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + int col_stride = dst->cols_stride(); + for (int i = 0; i < 4; i++) { + vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]); + } + } else { + int row_stride = dst->rows_stride(); + for (int i = 0; i < 4; i++) { + std::int8_t* col_ptr = dst_ptr + i; + vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0); + vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1); + vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2); + vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3); + vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4); + vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5); + vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6); + vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7); + } + } + } +}; + +inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) { + int8x8x2_t a[4]; + a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]); + a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]); + a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]); + a[3] = vtrn_s8(src.buf.reg[6], src.buf.reg[7]); + int16x4x2_t b[4]; + b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]), + vreinterpret_s16_s8(a[1].val[0])); + b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]), + vreinterpret_s16_s8(a[1].val[1])); + b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]), + vreinterpret_s16_s8(a[3].val[0])); + b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]), + vreinterpret_s16_s8(a[3].val[1])); + int32x2x2_t c[4]; + c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]), + vreinterpret_s32_s16(b[2].val[0])); 
+ c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]), + vreinterpret_s32_s16(b[3].val[0])); + c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]), + vreinterpret_s32_s16(b[2].val[1])); + c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]), + vreinterpret_s32_s16(b[3].val[1])); + RegBlockInt8<8, 8> result; + result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]); + result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]); + result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]); + result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]); + result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]); + result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]); + result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]); + result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]); + return result; +} + +template +struct StoreFinalOutputImpl, DstType> { + static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row, + int col) { + const auto& block = + DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src); + std::int8_t* dst_ptr = dst->data(row, col); + int stride = dst->stride(); + for (int i = 0; i < 8; i++) { + vst1_s8(dst_ptr + i * stride, block.buf.reg[i]); + } + } +}; template struct StoreFinalOutputImpl, DstType> { diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/pack.h gemmlowp-0.0~git20190708.a227af1/internal/pack.h --- gemmlowp-0.0~git20190128.58825b1/internal/pack.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/pack.h 2019-07-08 13:42:10.000000000 +0000 @@ -72,6 +72,10 @@ pos_ += n * KernelSideFormat::Cell::kSize; } + // TODO(suharshs): The datatype can now be int8 as well. We could introduce a + // new int8 current_data impl as well. This change would propagate to all pack + // impls and the Kernel::Run API, which all assume uint8. For now we leave + // this as-is pending future refactor. const std::uint8_t* current_data() const { return allocator_->GetPointer(data_handle_) + pos_; } @@ -208,6 +212,7 @@ public: typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat; typedef typename KernelSideFormat::Cell CellFormat; + typedef typename KernelSideFormat::InputScalar KernelInputScalar; typedef typename KernelSideFormat::Scalar KernelScalar; static const int kCells = KernelSideFormat::kCells; static const int kCellWidth = CellFormat::kWidth; @@ -216,7 +221,7 @@ static const int kCellSize = CellFormat::kSize; static const SideMapOrder kSrcOrder = SrcMapType::kOrder; static const int kZeroPointInputValue = - ZeroPointInputValue::kValue; + ZeroPointInputValue::kValue; PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {} @@ -233,7 +238,7 @@ std::uint8_t buf_[kKernelWidth * kRegisterSize]; public: - // Selects a block if in-place source data that's already a complete block + // Selects a block if in-place source data that's already a complete block. void UseCompleteSrcInPlace(const SrcMapType& src) { complete_src_ = src; } // Copies an incomplete block of source data into a local temporary // complete block by zero-extending it. @@ -249,7 +254,10 @@ memcpy(buf_ + d * kKernelWidth, src.data(0, d), src.width()); } } - complete_src_ = SrcMapType(buf_, kKernelWidth, kRegisterSize); + + // Since the KernelInputScalar type may not be uint8, we need to cast buf_. + complete_src_ = SrcMapType(reinterpret_cast(buf_), + kKernelWidth, kRegisterSize); } // Packs a complete block into the destination. 
This is the most // critical part and the part that we most typically want to @@ -340,7 +348,7 @@ } } - // Prefetches the data that will be read by PackL1 + // Prefetches the data that will be read by PackL1. void PrefetchL1(int start_width, int width, int start_depth, int depth) { if (SrcMapType::kOrder == SideMapOrder::WidthMajor) { for (int d = 0; d < depth; d += kDefaultCacheLineSize) { @@ -394,7 +402,7 @@ const SrcMapType& src_map_; }; -// Packs a block of the input LHS matrix, into a PackedSideBlock +// Packs a block of the input LHS matrix, into a PackedSideBlock. template void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) { ScopedProfilingLabel label("pack LHS"); @@ -409,7 +417,7 @@ impl.PackL2(); } -// Packs a block of the input RHS matrix, into a PackedSideBlock +// Packs a block of the input RHS matrix, into a PackedSideBlock. template void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) { ScopedProfilingLabel label("pack RHS"); diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/pack_neon.h gemmlowp-0.0~git20190708.a227af1/internal/pack_neon.h --- gemmlowp-0.0~git20190128.58825b1/internal/pack_neon.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/pack_neon.h 2019-07-08 13:42:10.000000000 +0000 @@ -26,6 +26,9 @@ typedef SideMap WidthMajorUint8SideMap; +typedef SideMap + WidthMajorInt8SideMap; + template using DepthMajorSideFormatNCells4x2 = KernelSideFormat, Cells>; @@ -295,6 +298,67 @@ sums2[i] = vaddl_s8(lo, hi); } int16x8_t sums4[Width / 2]; + for (int i = 0; i < Width / 2; i++) { + sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]); + } + if (Width == 4) { + int32x4_t sum = vld1q_s32(sums_ptr); + int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]); + sum = vpadalq_s16(sum, sums8); + vst1q_s32(sums_ptr, sum); + } else { + assert(Width == 2); + int32x2_t sum = vld1_s32(sums_ptr); + int16x4_t sums8 = + vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0])); + sum = vpadal_s16(sum, sums8); + vst1_s32(sums_ptr, sum); + } + dst->seek_forward_n_cells(1); + } +}; + +template +using Int8InputsFastKernelFormat = + KernelSideFormatInt8Inputs, 1>; + +// Same as above, but for int8 inputs, avoiding the uint8 -> int8 conversion. 
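The conversion being avoided here is the sign-bit flip that the uint8 fast-kernel pack applies to every value, which is also why ZeroPointInputValue is 128 for the uint8-input int8 kernels and 0 for the new int8-input format. In scalar form (the real packing code does this a whole vector register at a time; the helper name is invented for illustration):

#include <cstdint>

inline std::int8_t Uint8ToInt8ForFastKernel(std::uint8_t v) {
  // Flipping the top bit maps 0..255 onto -128..127, i.e. subtracts 128.
  return static_cast<std::int8_t>(v ^ 0x80);
}
// With KernelSideFormatInt8Inputs the data is already int8, so the pack just
// copies it through unchanged.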
+template +class PackingRegisterBlock>> + : public PackingRegisterBlockBase< + WidthMajorInt8SideMap, + PackedSideBlock>> { + public: + static_assert(Width == 2 || Width == 4, ""); + typedef Int8InputsFastKernelFormat KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock* dst, int start_width) { + std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width; + std::int8_t* dst_ptr = reinterpret_cast(dst->current_data()); + const std::int8_t* const src_ptr = this->complete_src_.data(); + const int stride = this->complete_src_.stride(); + // Load source WidthMajor data + int8x16_t src_lines[Width]; + for (int i = 0; i < Width; i++) { + src_lines[i] = vld1q_s8(src_ptr + i * stride); + } + for (int i = 0; i < Width; i++) { + vst1q_s8(dst_ptr + 16 * i, src_lines[i]); + } + int16x8_t sums2[Width]; + for (int i = 0; i < Width; i++) { + const int8x8_t lo = vget_low_s8(src_lines[i]); + const int8x8_t hi = vget_high_s8(src_lines[i]); + sums2[i] = vaddl_s8(lo, hi); + } + int16x8_t sums4[Width / 2]; for (int i = 0; i < Width / 2; i++) { sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]); } diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/platform.h gemmlowp-0.0~git20190708.a227af1/internal/platform.h --- gemmlowp-0.0~git20190128.58825b1/internal/platform.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/platform.h 2019-07-08 13:42:10.000000000 +0000 @@ -72,8 +72,8 @@ inline double real_time_in_seconds() { __int64 wintime; GetSystemTimeAsFileTime((FILETIME *)&wintime); - wintime -= 116444736000000000i64; // 1jan1601 to 1jan1970 - return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9; + wintime -= 116444736000000000LL; // 1jan1601 to 1jan1970 + return wintime / 10000000LL + wintime % 10000000LL * 100 * 1e-9; } #else diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/simd_wrappers_common_neon_sse.h gemmlowp-0.0~git20190708.a227af1/internal/simd_wrappers_common_neon_sse.h --- gemmlowp-0.0~git20190128.58825b1/internal/simd_wrappers_common_neon_sse.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/simd_wrappers_common_neon_sse.h 2019-07-08 13:42:10.000000000 +0000 @@ -350,6 +350,210 @@ } }; +// 4x1 := 4x1 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := 1x4 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 4x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = + 
SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 + 1x4 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 + 1x4 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := 4x4 + 4x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := 8x1 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 + 8x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := 8x4 + 1x4 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[6], 
DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := 8x4 + 8x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]); + result.buf.reg[4] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]); + result.buf.reg[5] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]); + result.buf.reg[6] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]); + result.buf.reg[7] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x8 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 8>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 8>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], Dup(rhs.buf.reg[0])); + return result; + } +}; + // 4x1 := 4x1 * 1x1 template <> struct BroadcastMulImpl, RegBlockInt32<1, 1>> { diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/simd_wrappers.h gemmlowp-0.0~git20190708.a227af1/internal/simd_wrappers.h --- gemmlowp-0.0~git20190128.58825b1/internal/simd_wrappers.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/simd_wrappers.h 2019-07-08 13:42:10.000000000 +0000 @@ -196,6 +196,153 @@ } template +struct BroadcastShiftLeftImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == 
Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = + ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template +typename BroadcastBinaryOpRegisterBlock::Type BroadcastShiftLeft( + const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs; + return BroadcastShiftLeftImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template +struct BroadcastSaturatingRoundingDoublingHighMulImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template +typename BroadcastBinaryOpRegisterBlock::Type +BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs; + return BroadcastSaturatingRoundingDoublingHighMulImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template +struct BroadcastRoundingDivideByPOTImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? 
c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = + RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template +typename BroadcastBinaryOpRegisterBlock::Type +BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs; + return BroadcastRoundingDivideByPOTImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template struct BroadcastMulImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock::Type; @@ -498,12 +645,16 @@ using RegBufferInt16 = RegisterBuffer; template using RegBufferUint8 = RegisterBuffer; +template +using RegBufferInt8 = RegisterBuffer; template using RegBlockInt32 = RegisterBlock; template using RegBlockInt16 = RegisterBlock; template using RegBlockUint8 = RegisterBlock; +template +using RegBlockInt8 = RegisterBlock; } // end namespace gemmlowp diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/simd_wrappers_neon.h gemmlowp-0.0~git20190708.a227af1/internal/simd_wrappers_neon.h --- gemmlowp-0.0~git20190128.58825b1/internal/simd_wrappers_neon.h 2019-01-28 15:33:33.000000000 +0000 +++ gemmlowp-0.0~git20190708.a227af1/internal/simd_wrappers_neon.h 2019-07-08 13:42:10.000000000 +0000 @@ -25,6 +25,7 @@ using Int16x4 = int16x4_t; using Int16x8 = int16x8_t; using Uint8x8 = uint8x8_t; +using Int8x8 = int8x8_t; template struct RegisterType { @@ -48,6 +49,14 @@ std::uint8_t>::type>::type; }; +template +struct RegisterType { + using Type = typename std::conditional< + ScalarCount >= 8, Int8x8, + typename std::conditional= 4, std::int32_t, + std::int8_t>::type>::type; +}; + inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); } inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); } inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); } @@ -92,6 +101,10 @@ inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); } +inline Int32x4 Max(Int32x4 a, std::int32_t b) { + return vmaxq_s32(a, vdupq_n_s32(b)); +} + inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) { return vqrdmulhq_n_s32(a, b); } @@ -164,6 +177,17 @@ }; template <> +struct LoadContiguousImpl> { + static RegBlockInt8<8, 8> Run(const std::int8_t* src) { + RegBlockInt8<8, 8> result; + for (int i = 0; i < 8; i++) { + result.buf.reg[i] = vld1_s8(src + 8 * i); + } + return result; + } +}; + +template <> struct LoadContiguousImpl> { static RegBlockInt32<8, 8> Run(const std::int32_t* src) { RegBlockInt32<8, 8> result; @@ -173,6 +197,352 @@ return result; } }; + +// 4x1 := 4x1 + 1x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := 1x4 + 1x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 4x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> 
Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 + 1x4 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 + 1x4 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := 4x4 + 4x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := 8x1 + 1x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 + 8x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := 8x4 + 1x4 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := 8x4 + 8x1 +template <> +struct BroadcastShiftLeftImpl, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], 
rhs.buf.reg[0]);
+    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]);
+    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]);
+    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]);
+    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]);
+    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 8>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 4x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
+                                        RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
+                                        RegBlockInt32<1, 4>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
+                                        RegBlockInt32<1, 4>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] =
+        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[2] =
+        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[3] =
+        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 4x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
+                                        RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]);
+    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p);
+    }
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 8x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
+                                        RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
+                                        RegBlockInt32<1, 4>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] =
+        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[2] =
+        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[3] =
+        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[4] =
+        RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[5] =
+        RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[6] =
+        RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+    result.buf.reg[7] =
+        RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 8x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
+                                        RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
+    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]);
+    result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]);
+    result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]);
+    result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]);
+    result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
+                                        RegBlockInt32<1, 8>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 8>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
+                                        RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] =
+        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    result.buf.reg[1] =
+        RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};

 }  // end namespace gemmlowp
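Editorial note on the specializations above: they only change how the arithmetic is laid out across NEON/SSE registers; per lane, RoundingDivideByPOT still means dividing an int32 value by two raised to the corresponding exponent lane, rounding to nearest (in gemmlowp's scalar reference, ties round away from zero). The sketch below is a scalar model of that arithmetic for reference only; it is written in Python, is not part of gemmlowp or of this patch, and its function names are made up.

# Scalar model (illustration only, not gemmlowp code) of the per-lane
# semantics of RoundingDivideByPOT: y = round(x / 2**exponent), computed
# as a rounding arithmetic right shift, with ties rounded away from zero.
def rounding_divide_by_pot(x, exponent):
  assert 0 <= exponent <= 31
  mask = (1 << exponent) - 1
  remainder = x & mask  # low bits that get shifted out
  threshold = (mask >> 1) + (1 if x < 0 else 0)
  return (x >> exponent) + (1 if remainder > threshold else 0)


# Mirrors the 'broadcast' shapes above: a single-lane rhs exponent applies
# to every lhs lane, otherwise lanes pair up one-to-one.
def broadcast_rounding_divide_by_pot(lhs, rhs):
  if len(rhs) == 1:
    rhs = rhs * len(lhs)
  return [rounding_divide_by_pot(x, e) for x, e in zip(lhs, rhs)]

For example, rounding_divide_by_pot(5, 1) == 3 and rounding_divide_by_pot(-5, 1) == -3.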
diff -Nru gemmlowp-0.0~git20190128.58825b1/internal/unpack.h gemmlowp-0.0~git20190708.a227af1/internal/unpack.h
--- gemmlowp-0.0~git20190128.58825b1/internal/unpack.h 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/internal/unpack.h 2019-07-08 13:42:10.000000000 +0000
@@ -98,12 +98,14 @@
                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                        int depth, int src_row, int src_col, int src_global_row,
                        int src_global_col, int dst_row, int dst_col) {
+  using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
   using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
+  using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
   using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
   static constexpr int KernelLhsZeroPointInput =
-      ZeroPointInputValue<KernelLhsScalar>::kValue;
+      ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
   static constexpr int KernelRhsZeroPointInput =
-      ZeroPointInputValue<KernelRhsScalar>::kValue;
+      ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
   auto acc = Load(src, src_row, src_col);
   const auto& lhs_sums_of_each_slice_block =
       LoadForBroadcasting(lhs_sums_of_each_slice, src_row);
diff -Nru gemmlowp-0.0~git20190128.58825b1/meta/generators/cc_emitter.py gemmlowp-0.0~git20190708.a227af1/meta/generators/cc_emitter.py
--- gemmlowp-0.0~git20190128.58825b1/meta/generators/cc_emitter.py 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/meta/generators/cc_emitter.py 2019-07-08 13:42:10.000000000 +0000
@@ -52,16 +52,16 @@
     self.indent = self.indent[:-2]

   def EmitIndented(self, what):
-    print self.indent + what
+    print(self.indent + what)

   def EmitNewline(self):
-    print ''
+    print('')

   def EmitPreprocessor1(self, op, param):
-    print '#%s %s' % (op, param)
+    print('#%s %s' % (op, param))

   def EmitPreprocessor(self, op):
-    print '#%s' % op
+    print('#%s' % op)

   def EmitInclude(self, include):
     self.EmitPreprocessor1('include', include)
diff -Nru gemmlowp-0.0~git20190128.58825b1/meta/generators/common.py gemmlowp-0.0~git20190708.a227af1/meta/generators/common.py
--- gemmlowp-0.0~git20190128.58825b1/meta/generators/common.py 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/meta/generators/common.py 2019-07-08 13:42:10.000000000 +0000
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """."""
+import collections

 _HEADER_COPYRIGHT = (
     '''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
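The cc_emitter.py hunk above, this common.py change, and the neon_emitter hunks below are a mechanical Python 2-to-3 port: print statements become print() calls, and the callable() test is replaced by an isinstance() check against the Callable ABC, hence the new collections import. One editorial caveat, not part of the patch: since Python 3.3 that ABC lives in collections.abc, and the bare collections.Callable alias was removed in Python 3.10, so an equivalent, more forward-compatible form of the same check would look like the sketch below (the helper name is hypothetical).

import collections.abc


def has_callable_emit_pack(obj):
  # Same feature test the generators perform on self.EmitPack, written
  # against collections.abc; the builtin callable() would also work here.
  emit_pack = getattr(obj, 'EmitPack', None)
  return isinstance(emit_pack, collections.abc.Callable)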
@@ -71,7 +72,7 @@
     self.emitter = emitter

   def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers):
-    if callable(getattr(self, 'EmitPack', None)):
+    if isinstance(getattr(self, 'EmitPack', None), collections.Callable):
       template_params = [in_type, lanes_count, pack_size, leftovers, self.name]
       self.emitter.EmitMemberFunctionBegin(
           'Stream', [], template_params, 'Pack',
diff -Nru gemmlowp-0.0~git20190128.58825b1/meta/generators/neon_emitter_64.py gemmlowp-0.0~git20190708.a227af1/meta/generators/neon_emitter_64.py
--- gemmlowp-0.0~git20190128.58825b1/meta/generators/neon_emitter_64.py 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/meta/generators/neon_emitter_64.py 2019-07-08 13:42:10.000000000 +0000
@@ -423,7 +423,7 @@
     self.indent = self.indent[:-delta]

   def EmitIndented(self, what):
-    print self.indent + what
+    print(self.indent + what)

   def PushOp(self, op):
     if op in self.ops.keys():
@@ -435,13 +435,13 @@
     self.ops.clear()

   def EmitNewline(self):
-    print ''
+    print('')

   def EmitPreprocessor1(self, op, param):
-    print '#%s %s' % (op, param)
+    print('#%s %s' % (op, param))

   def EmitPreprocessor(self, op):
-    print '#%s' % op
+    print('#%s' % op)

   def EmitInclude(self, include):
     self.EmitPreprocessor1('include', include)
diff -Nru gemmlowp-0.0~git20190128.58825b1/meta/generators/neon_emitter.py gemmlowp-0.0~git20190708.a227af1/meta/generators/neon_emitter.py
--- gemmlowp-0.0~git20190128.58825b1/meta/generators/neon_emitter.py 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/meta/generators/neon_emitter.py 2019-07-08 13:42:10.000000000 +0000
@@ -187,7 +187,7 @@
     self.indent = self.indent[:-delta]

   def EmitIndented(self, what):
-    print self.indent + what
+    print(self.indent + what)

   def PushOp(self, op):
     if op in self.ops.keys():
@@ -199,13 +199,13 @@
     self.ops.clear()

   def EmitNewline(self):
-    print ''
+    print('')

   def EmitPreprocessor1(self, op, param):
-    print '#%s %s' % (op, param)
+    print('#%s %s' % (op, param))

   def EmitPreprocessor(self, op):
-    print '#%s' % op
+    print('#%s' % op)

   def EmitInclude(self, include):
     self.EmitPreprocessor1('include', include)
diff -Nru gemmlowp-0.0~git20190128.58825b1/profiling/instrumentation.h gemmlowp-0.0~git20190708.a227af1/profiling/instrumentation.h
--- gemmlowp-0.0~git20190128.58825b1/profiling/instrumentation.h 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/profiling/instrumentation.h 2019-07-08 13:42:10.000000000 +0000
@@ -108,7 +108,7 @@
 // contains pointers to literal strings that were manually entered
 // in the instrumented code (see ScopedProfilingLabel).
 struct ProfilingStack {
-  static const std::size_t kMaxSize = 14;
+  static const std::size_t kMaxSize = 30;
   typedef const char* LabelsArrayType[kMaxSize];
   LabelsArrayType labels;
   std::size_t size;
diff -Nru gemmlowp-0.0~git20190128.58825b1/public/bit_depth.h gemmlowp-0.0~git20190708.a227af1/public/bit_depth.h
--- gemmlowp-0.0~git20190128.58825b1/public/bit_depth.h 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/public/bit_depth.h 2019-07-08 13:42:10.000000000 +0000
@@ -24,14 +24,15 @@
 struct OperandRange {
   static const int kMinValue = tMinValue;
   static const int kMaxValue = tMaxValue;
-  static_assert(0 <= kMinValue, "");
   static_assert(kMinValue < kMaxValue, "");
-  static_assert(kMaxValue <= 255, "");
 };

 using Uint8Range = OperandRange<0, 255>;
 using Uint8RangeExcludingZero = OperandRange<1, 255>;

+using Int8Range = OperandRange<-128, 127>;
+using Int8RangeExcludingLow = OperandRange<-127, 127>;
+
 template <typename tLhsRange, typename tRhsRange>
 struct BitDepthParams {
   using LhsRange = tLhsRange;
@@ -47,6 +48,11 @@
 using L8R8WithLhsNonzeroBitDepthParams =
     BitDepthParams<Uint8RangeExcludingZero, Uint8Range>;

+// Signed Variant: This allows using faster kernels with signed arithmetic; see
+// NEON_64bit_GEMM_Int8Operands_Int32Accumulators_AccumTwoWithin16Bits
+using SignedL8R8WithLhsNonzeroBitDepthParams =
+    BitDepthParams<Int8RangeExcludingLow, Int8Range>;
+
 // Deprecated: when gemmlowp used to allow requantizing 8bit
 // inputs to less-than-8-bit depths, the public setting allowing
 // that was DefaultL7R5BitDepthParams. That requantization
diff -Nru gemmlowp-0.0~git20190128.58825b1/public/output_stages.h gemmlowp-0.0~git20190708.a227af1/public/output_stages.h
--- gemmlowp-0.0~git20190128.58825b1/public/output_stages.h 2019-01-28 15:33:33.000000000 +0000
+++ gemmlowp-0.0~git20190708.a227af1/public/output_stages.h 2019-07-08 13:42:10.000000000 +0000
@@ -138,12 +138,37 @@
   std::int32_t result_offset_after_shift;
 };

+// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
+// is not necessarily just a right shift, so we can represent multipliers
+// greater than 1. This takes a result_exponent parameter; when it's
+// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
+// with result_shift = -result_exponent.
+// In the general case, this consists in first left-shifting by
+// std::max(result_exponent, 0), before doing the same as
+// OutputStageQuantizeDownInt32ByFixedPoint with
+// result_shift = std::max(-result_exponent, 0).
+//
+// The difference from OutputStageScaleInt32ByFixedPointAndExponent is that
+// each row or column of the output (depending on tShape) has its own
+// result_fixedpoint_multiplier and result_exponent numbers.
+template <VectorShape tShape>
+struct OutputStageScaleInt32ByFixedPointAndExponentPC {
+  VectorMap<const std::int32_t, tShape> result_fixedpoint_multiplier;
+  VectorMap<const std::int32_t, tShape> result_exponent;
+  std::int32_t result_offset_after_shift;
+};
+
 // This output stage takes int32 values that are expected to be already
 // on the final uint8 scale, but not necessarily in the [0..255] range.
 // It clamps them to the [0..255] range and returns them casted to uint8.
 struct OutputStageSaturatingCastToUint8 {};

 // This output stage takes int32 values that are expected to be already
+// on the final int8 scale, but not necessarily in the [-128..127] range.
+// It clamps them to the [-128..127] range and returns them casted to int8.
+struct OutputStageSaturatingCastToInt8 {};
+
+// This output stage takes int32 values that are expected to be already
 // in the [0..255] range and returns them casted to uint8.
 // This stage can save time if used instead of the
 // OutputStageSaturatingCastToUint8 stage immediately after the