19# include "config_auto.h"
31#if defined(_MSC_VER) && !defined(__clang__)
42# define PARALLEL_IF_OPENMP(__num_threads) \
43 PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
44 PRAGMA(omp sections nowait) { \
46# define SECTION_IF_OPENMP \
49# define END_PARALLEL_IF_OPENMP \
56# define PRAGMA(x) __pragma(x)
58# define PRAGMA(x) _Pragma(# x)
62# define PARALLEL_IF_OPENMP(__num_threads)
63# define SECTION_IF_OPENMP
64# define END_PARALLEL_IF_OPENMP
76static inline uint32_t ceil_log2(uint32_t n) {
80 uint32_t l2 = 31 - __builtin_clz(n);
81#elif defined(_MSC_VER)
84 _BitScanReverse(&l2, n);
98 return (n == (1u << l2)) ? l2 : l2 + 1;
106 , is_2d_(two_dimensional)
109 if (two_dimensional) {
138 if (softmax_ !=
nullptr) {
159 for (
int w = 0; w <
WT_COUNT; ++w) {
168 if (softmax_ !=
nullptr) {
178 for (
int w = 0; w <
WT_COUNT; ++w) {
185 if (softmax_ !=
nullptr) {
194 if (softmax_ !=
nullptr) {
203 for (
int w = 0; w <
WT_COUNT; ++w) {
209 if (softmax_ !=
nullptr) {
216 for (
int w = 0; w <
WT_COUNT; ++w) {
220 std::ostringstream msg;
221 msg <<
name_ <<
" Gate weights " << w;
222 gate_weights_[w].
Debug2D(msg.str().c_str());
224 if (softmax_ !=
nullptr) {
237 for (
int w = 0; w <
WT_COUNT; ++w) {
245 if (softmax_ !=
nullptr && !softmax_->
Serialize(fp)) {
260 nf_ = ceil_log2(
no_);
265 for (
int w = 0; w <
WT_COUNT; ++w) {
274 is_2d_ = na_ - nf_ ==
ni_ + 2 * ns_;
280 if (softmax_ ==
nullptr) {
294 input_width_ = input.
Width();
295 if (softmax_ !=
nullptr) {
302 ResizeForward(input);
309 for (
auto &temp_line : temp_lines) {
310 temp_line.Init(ns_, ro, scratch);
314 curr_state.
Init(ns_, scratch);
315 ZeroVector<TFloat>(ns_, curr_state);
316 curr_output.
Init(ns_, scratch);
317 ZeroVector<TFloat>(ns_, curr_output);
322 std::vector<NetworkScratch::FloatVec> states, outputs;
324 states.resize(buf_width);
325 outputs.resize(buf_width);
326 for (
int i = 0;
i < buf_width; ++
i) {
327 states[
i].Init(ns_, scratch);
328 ZeroVector<TFloat>(ns_, states[
i]);
329 outputs[
i].Init(ns_, scratch);
330 ZeroVector<TFloat>(ns_, outputs[
i]);
336 if (softmax_ !=
nullptr) {
337 softmax_output.
Init(
no_, scratch);
338 ZeroVector<TFloat>(
no_, softmax_output);
339 int rounded_softmax_inputs = gate_weights_[
CI].
RoundInputs(ns_);
341 int_output.
Resize2d(
true, 1, rounded_softmax_inputs, scratch);
346 curr_input.
Init(na_, scratch);
351 int t = src_index.
t();
353 bool valid_2d =
Is2D();
361 int mod_t =
Modulo(t, buf_width);
364 if (softmax_ !=
nullptr) {
385 FuncInplace<GFunc>(ns_, temp_lines[
CI]);
394 FuncInplace<FFunc>(ns_, temp_lines[
GI]);
403 FuncInplace<FFunc>(ns_, temp_lines[
GF1]);
412 FuncInplace<FFunc>(ns_, temp_lines[
GFS]);
422 FuncInplace<FFunc>(ns_, temp_lines[
GO]);
429 int8_t *which_fg_col = which_fg_[t];
430 memset(which_fg_col, 1, ns_ *
sizeof(which_fg_col[0]));
432 const TFloat *stepped_state = states[mod_t];
433 for (
int i = 0;
i < ns_; ++
i) {
434 if (temp_lines[
GF1][
i] < temp_lines[
GFS][
i]) {
435 curr_state[
i] = temp_lines[
GFS][
i] * stepped_state[
i];
454 FuncMultiply<HFunc>(curr_state, temp_lines[
GO], ns_, curr_output);
458 if (softmax_ !=
nullptr) {
465 output->WriteTimeStep(t, softmax_output);
472 output->WriteTimeStep(dest_index.
t(), curr_output);
476 output->WriteTimeStep(t, curr_output);
486 ZeroVector<TFloat>(ns_, curr_state);
487 ZeroVector<TFloat>(ns_, curr_output);
498#ifndef GRAPHICS_DISABLED
509#ifndef GRAPHICS_DISABLED
518 outputerr.
Init(ns_, scratch);
521 curr_stateerr.
Init(ns_, scratch);
522 curr_sourceerr.
Init(na_, scratch);
523 ZeroVector<TFloat>(ns_, curr_stateerr);
524 ZeroVector<TFloat>(na_, curr_sourceerr);
527 for (
auto &gate_error : gate_errors) {
528 gate_error.
Init(ns_, scratch);
533 std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
535 stateerr.resize(buf_width);
536 sourceerr.resize(buf_width);
537 for (
int t = 0; t < buf_width; ++t) {
538 stateerr[t].Init(ns_, scratch);
539 sourceerr[t].Init(na_, scratch);
540 ZeroVector<TFloat>(ns_, stateerr[t]);
541 ZeroVector<TFloat>(na_, sourceerr[t]);
546 for (
auto &sourceerr_temp : sourceerr_temps) {
547 sourceerr_temp.
Init(na_, scratch);
549 int width = input_width_;
552 for (
auto &w : gate_errors_t) {
553 w.
Init(ns_, width, scratch);
558 if (softmax_ !=
nullptr) {
559 softmax_errors.
Init(
no_, scratch);
560 softmax_errors_t.
Init(
no_, width, scratch);
565 fwd_deltas.
Print(10);
573 int t = dest_index.
t();
583 up_pos = up_index.
t();
589 down_pos = down_index.
t();
594 int mod_t =
Modulo(t, buf_width);
597 ZeroVector<TFloat>(na_, curr_sourceerr);
598 ZeroVector<TFloat>(ns_, curr_stateerr);
606 ZeroVector<TFloat>(ns_, outputerr);
608 }
else if (softmax_ ==
nullptr) {
611 softmax_->
BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.
get(), outputerr);
621 const float *next_node_gf1 = node_values_[
GF1].
f(t + 1);
622 for (
int i = 0;
i < ns_; ++
i) {
623 curr_stateerr[
i] *= next_node_gf1[
i];
626 if (
Is2D() && t + 1 < width) {
627 for (
int i = 0;
i < ns_; ++
i) {
628 if (which_fg_[t + 1][
i] != 1) {
629 curr_stateerr[
i] = 0.0;
633 const float *right_node_gfs = node_values_[
GFS].
f(down_pos);
634 const TFloat *right_stateerr = stateerr[mod_t];
635 for (
int i = 0;
i < ns_; ++
i) {
636 if (which_fg_[down_pos][
i] == 2) {
637 curr_stateerr[
i] += right_stateerr[
i] * right_node_gfs[
i];
644 ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
646 if (t + 10 > width) {
648 for (
int i = 0;
i < ns_; ++
i)
649 tprintf(
" %g,%g,%g", curr_stateerr[
i], outputerr[
i], curr_sourceerr[
ni_ + nf_ +
i]);
657 node_values_[
CI].FuncMultiply3<
GPrime>(t, node_values_[
GI], t, curr_stateerr, gate_errors[
CI]);
664 node_values_[
GI].FuncMultiply3<
FPrime>(t, node_values_[
CI], t, curr_stateerr, gate_errors[
GI]);
672 node_values_[
GF1].FuncMultiply3<
FPrime>(t, state_, t - 1, curr_stateerr, gate_errors[
GF1]);
676 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[
GF1][0]));
677 memset(sourceerr_temps[
GF1], 0, na_ *
sizeof(*sourceerr_temps[
GF1]));
683 node_values_[
GFS].FuncMultiply3<
FPrime>(t, state_, up_pos, curr_stateerr, gate_errors[
GFS]);
687 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[
GFS][0]));
688 memset(sourceerr_temps[
GFS], 0, na_ *
sizeof(*sourceerr_temps[
GFS]));
703 sourceerr_temps[
GO], sourceerr_temps[
GFS], curr_sourceerr);
707 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
708 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
712 for (
int w = 0; w <
WT_COUNT; ++w) {
719 source_t.
Init(na_, width, scratch);
721 state_t.
Init(ns_, width, scratch);
724# pragma omp parallel for num_threads(GFS) if (!Is2D())
726 for (
int w = 0; w <
WT_COUNT; ++w) {
732 if (softmax_ !=
nullptr) {
740void LSTM::Update(
float learning_rate,
float momentum,
float adam_beta,
int num_samples) {
744 for (
int w = 0; w <
WT_COUNT; ++w) {
748 gate_weights_[w].
Update(learning_rate, momentum, adam_beta, num_samples);
750 if (softmax_ !=
nullptr) {
751 softmax_->
Update(learning_rate, momentum, adam_beta, num_samples);
763 const LSTM *lstm =
static_cast<const LSTM *
>(&other);
764 for (
int w = 0; w <
WT_COUNT; ++w) {
770 if (softmax_ !=
nullptr) {
780 for (
int w = 0; w <
WT_COUNT; ++w) {
784 tprintf(
"Gate %d, inputs\n", w);
785 for (
int i = 0;
i <
ni_; ++
i) {
787 for (
int s = 0; s < ns_; ++s) {
788 tprintf(
" %g", gate_weights_[w].GetWeights(s)[
i]);
792 tprintf(
"Gate %d, outputs\n", w);
795 for (
int s = 0; s < ns_; ++s) {
796 tprintf(
" %g", gate_weights_[w].GetWeights(s)[
i]);
801 for (
int s = 0; s < ns_; ++s) {
802 tprintf(
" %g", gate_weights_[w].GetWeights(s)[na_]);
811 for (
int w = 0; w <
WT_COUNT; ++w) {
815 tprintf(
"Gate %d, inputs\n", w);
816 for (
int i = 0;
i <
ni_; ++
i) {
818 for (
int s = 0; s < ns_; ++s) {
819 tprintf(
" %g", gate_weights_[w].GetDW(s,
i));
823 tprintf(
"Gate %d, outputs\n", w);
826 for (
int s = 0; s < ns_; ++s) {
827 tprintf(
" %g", gate_weights_[w].GetDW(s,
i));
832 for (
int s = 0; s < ns_; ++s) {
833 tprintf(
" %g", gate_weights_[w].GetDW(s, na_));
842void LSTM::ResizeForward(
const NetworkIO &input) {
844 source_.
Resize(input, rounded_inputs);
848 for (
int w = 0; w <
WT_COUNT; ++w) {
#define END_PARALLEL_IF_OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
#define SECTION_IF_OPENMP
void CopyVector(unsigned n, const TFloat *src, TFloat *dest)
void tprintf(const char *format,...)
void SumVectors(int n, const TFloat *v1, const TFloat *v2, const TFloat *v3, const TFloat *v4, const TFloat *v5, TFloat *sum)
void MultiplyAccumulate(int n, const TFloat *u, const TFloat *v, TFloat *out)
@ NT_LSTM_SOFTMAX_ENCODED
void CodeInBinary(int n, int nf, TFloat *vec)
void AccumulateVector(int n, const TFloat *src, TFloat *dest)
void ClipVector(int n, T lower, T upper, T *vec)
void MultiplyVectorsInPlace(int n, const TFloat *src, TFloat *inout)
void ResizeNoInit(int size1, int size2, int pad=0)
int RoundOutputs(int size) const
static const IntSimdMatrix * intSimdMatrix
bool DeSerialize(std::string &data)
bool Serialize(const std::string &data)
void ForwardTimeStep(int t, TFloat *output_line)
void DebugWeights() override
void FinishBackward(const TransposedArray &errors_t)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void SetEnableTraining(TrainingState state) override
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
int InitWeights(float range, TRand *randomizer) override
void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, TransposedArray *errors_t, TFloat *backprop)
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void ConvertToInt() override
StaticShape OutputShape(const StaticShape &input_shape) const override
bool Serialize(TFile *fp) const override
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
TESS_API LSTM(const std::string &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
int InitWeights(float range, TRand *randomizer) override
void DebugWeights() override
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override
bool DeSerialize(TFile *fp) override
bool Serialize(TFile *fp) const override
void ConvertToInt() override
void SetEnableTraining(TrainingState state) override
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
StaticShape OutputShape(const StaticShape &input_shape) const override
void DisplayForward(const NetworkIO &matrix)
void DisplayBackward(const NetworkIO &matrix)
static Network * CreateFromFile(TFile *fp)
virtual bool Serialize(TFile *fp) const
bool TestFlag(NetworkFlags flag) const
virtual void SetRandomizer(TRand *randomizer)
void Resize(const NetworkIO &src, int num_features)
void WriteTimeStepPart(int t, int offset, int num_features, const TFloat *input)
void ResizeFloat(const NetworkIO &src, int num_features)
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features, const NetworkIO &src, int src_t, int src_offset)
void WriteTimeStep(int t, const TFloat *input)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const TFloat *w, TFloat *product) const
void Print(int num) const
void ReadTimeStep(int t, TFloat *output) const
void Func2Multiply3(const NetworkIO &v_io, int t, const TFloat *w, TFloat *product) const
void Transpose(TransposedArray *dest) const
const StrideMap & stride_map() const
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
const int8_t * i(int t) const
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
void Init(int, int reserve, NetworkScratch *scratch)
TransposedArray * get() const
void Init(int size1, int size2, NetworkScratch *scratch)
void set_depth(int value)
void set_width(int value)
int Size(FlexDimensions dimension) const
int index(FlexDimensions dimension) const
bool AddOffset(int offset, FlexDimensions dimension)
bool IsLast(FlexDimensions dimension) const
void WriteStrided(int t, const float *data)
void PrintUnTransposed(int num)
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
void Debug2D(const char *msg)
void VectorDotMatrix(const TFloat *u, TFloat *v) const
void MatrixDotVector(const TFloat *u, TFloat *v) const
int RoundInputs(int size) const
void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const