bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                    NetworkIO *back_deltas) {
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayBackward(fwd_deltas);
  }
#endif
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // Scratch space.
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<TFloat>(ns_, curr_stateerr);
  ZeroVector<TFloat>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (auto &gate_error : gate_errors) {
    gate_error.Init(ns_, scratch);
  }
  // Rotating buffers of width buf_width allow storage of the recurrent
  // time-steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.resize(buf_width);
    sourceerr.resize(buf_width);
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<TFloat>(ns_, stateerr[t]);
      ZeroVector<TFloat>(na_, sourceerr[t]);
    }
  }
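  // Note: these revolving buffers are indexed below with Modulo(t, buf_width),
  // so the backward pass only ever keeps one row's worth of saved stateerr and
  // sourceerr for the 2-D recurrence, rather than the whole image.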
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (auto &sourceerr_temp : sourceerr_temps) {
    sourceerr_temp.Init(na_, scratch);
  }
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (auto &w : gate_errors_t) {
    w.Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  TFloat state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.c_str());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step; each is only
    // valid if >= 0.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) {
          up_pos = up_index.t();
        }
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) {
          down_pos = down_index.t();
        }
      }
    }
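    // Note: up_pos indexes the row above, whose forward state is needed for the
    // GFS gate error; down_pos indexes the row below, which this backward scan
    // has already processed, so its saved stateerr/sourceerr can be reused.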
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width);
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<TFloat>(na_, curr_sourceerr);
      ZeroVector<TFloat>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<TFloat>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.get(), outputerr);
    }
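    // outputerr now holds the error on this timestep's output: for
    // NT_LSTM_SUMMARY only the last x position of each row receives a delta,
    // with an attached softmax the delta is first backpropagated through that
    // layer, and otherwise it is read directly from fwd_deltas.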
    if (!at_last_x) {
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    }
    if (down_pos >= 0) {
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    }
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float *next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
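    // This is the recurrent part of the cell-state gradient: curr_stateerr,
    // carried over from the previous (t + 1) iteration, is scaled by that
    // step's forget gate, in effect dE/ds(t) starts as GF1(t+1) * dE/ds(t+1),
    // before this step's own contributions are added below.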
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) {
          curr_stateerr[i] = 0.0;
        }
      }
      if (down_pos >= 0) {
        const float *right_node_gfs = node_values_[GFS].f(down_pos);
        const TFloat *right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
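    // which_fg_ records which forget-gate direction was chosen per cell in the
    // forward pass: the carried 1-D error survives only where direction 1 (x)
    // was chosen at t + 1, and the saved error of the row below is added,
    // scaled by its GFS gate, where direction 2 (y) was chosen there.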
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr, curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
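    // Output-path contribution to the state error, in effect
    // dE/ds(t) += outputerr * GO(t) * H'(s(t)), followed by clipping to
    // +/- state_clip to keep the recurrent gradient numerically stable.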
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i) {
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], curr_sourceerr[ni_ + nf_ + i]);
      }
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t, curr_stateerr,
                                           gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
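    // Cell-input error, roughly delta_CI = dE/ds * GI * g'(CI). VectorDotMatrix
    // backpropagates it through the CI gate weights into sourceerr_temps[CI];
    // WriteStrided stores it transposed for the weight-gradient sum after the loop.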

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t, curr_stateerr,
                                           gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
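    // 1-D forget-gate error, roughly delta_GF1 = dE/ds * s(t-1) * f'(GF1);
    // at t == 0 there is no previous state, so the gate error and its source
    // contribution are simply zeroed.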

    // 2-D forget gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) {
      gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr, gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], sourceerr_temps[GF1],
               sourceerr_temps[GO], sourceerr_temps[GFS], curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
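    // The per-gate source errors are summed into curr_sourceerr, which is both
    // written to back_deltas for the layer below and re-read (via the ni_ + nf_
    // offsets above) as the recurrent output error of the next-earlier timestep.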
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed copies of source_ and state_ speed up the outer-product sums below.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
#ifdef _OPENMP
#  pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
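  // Weight-gradient accumulation: for each gate, SumOuterTransposed computes in
  // effect dW += sum over t of gate_error(t) outer source(t), using the
  // transposed copies so the sum runs over contiguous memory. The GFS gate is
  // skipped when the network is not 2-D.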
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}