// NOTE(review): This chunk is a garbled extraction of the MIDDLE of an LSTM
// backward pass (error back-propagation through gates CI/GI/GF1/GFS/GO).
// The leading integers fused into each line (416, 419, ...) are line numbers
// from the original file, not code; the numbering jumps show that many
// original lines are missing here, so braces/loops do not balance within
// this fragment. Code bytes are preserved exactly; only comments are added.
// Resize the output deltas to match the input map (ni_ inputs per step).
416 back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_,
ni_);
// Scratch vector for the error coming in from the layer above at one step.
419 NetworkScratch::FloatVec outputerr;
420 outputerr.Init(ns_, scratch);
// Running error w.r.t. the cell state (ns_) and w.r.t. the full source
// vector (na_ = inputs + recurrent parts) for the current timestep.
422 NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
423 curr_stateerr.Init(ns_, scratch);
424 curr_sourceerr.Init(na_, scratch);
425 ZeroVector<double>(ns_, curr_stateerr);
426 ZeroVector<double>(na_, curr_sourceerr);
// One error vector per gate/weight type.
428 NetworkScratch::FloatVec gate_errors[
WT_COUNT];
429 for (
int g = 0; g <
WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
// Rotating buffers (buf_width entries) of past state/source errors —
// presumably to reach the previous row in the 2-D case; TODO confirm,
// the buf_width setup lines are missing from this fragment.
435 stateerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
436 sourceerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
437 for (
int t = 0; t < buf_width; ++t) {
438 stateerr[t].Init(ns_, scratch);
439 sourceerr[t].Init(na_, scratch);
440 ZeroVector<double>(ns_, stateerr[t]);
441 ZeroVector<double>(na_, sourceerr[t]);
// Per-gate partial source errors, summed into curr_sourceerr further down.
// (The loop header for the Init below is one of the missing lines.)
445 NetworkScratch::FloatVec sourceerr_temps[
WT_COUNT];
447 sourceerr_temps[w].Init(na_, scratch);
448 int width = input_width_;
// Transposed per-gate error stores covering all `width` timesteps; these
// feed the weight-gradient computation at the bottom of this fragment.
450 NetworkScratch::GradientStore gate_errors_t[
WT_COUNT];
451 for (
int w = 0; w <
WT_COUNT; ++w) {
452 gate_errors_t[w].Init(ns_, width, scratch);
// Softmax error buffers are only needed when a softmax layer is attached.
455 NetworkScratch::FloatVec softmax_errors;
456 NetworkScratch::GradientStore softmax_errors_t;
457 if (softmax_ != NULL) {
458 softmax_errors.Init(
no_, scratch);
459 softmax_errors_t.Init(
no_, width, scratch);
// State error is clipped harder in 1-D (4.0) than 2-D (9.0).
461 double state_clip =
Is2D() ? 9.0 : 4.0;
// Debug dump of the incoming deltas (guard condition is a missing line).
464 fwd_deltas.Print(10);
// Walk both the destination (input map) and source (fwd_deltas) strides
// backwards from the last timestep.
466 StrideMap::Index dest_index(input_map_);
467 dest_index.InitToLast();
469 StrideMap::Index src_index(fwd_deltas.stride_map());
470 src_index.InitToLast();
472 int t = dest_index.t();
473 bool at_last_x = dest_index.IsLast(
FD_WIDTH);
// 2-D neighbours: position one row up and one row down of the current
// element, when such offsets exist (up_pos/down_pos declarations are in
// missing lines; presumably initialized to an invalid sentinel — confirm).
480 StrideMap::Index up_index(dest_index);
481 if (up_index.AddOffset(-1,
FD_HEIGHT)) up_pos = up_index.t();
484 StrideMap::Index down_index(dest_index);
485 if (down_index.AddOffset(1,
FD_HEIGHT)) down_pos = down_index.t();
// Index into the rotating stateerr/sourceerr buffers.
489 int mod_t =
Modulo(t, buf_width);
// Reset accumulated errors — the guarding condition is missing here, but
// presumably this fires at sequence/row boundaries; TODO confirm.
492 ZeroVector<double>(na_, curr_sourceerr);
493 ZeroVector<double>(ns_, curr_stateerr);
// Pull the error for this step from the layer above, consuming src_index
// in step with dest_index; otherwise the step gets zero error.
498 fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
499 src_index.Decrement();
501 ZeroVector<double>(ns_, outputerr);
503 }
else if (softmax_ == NULL) {
504 fwd_deltas.ReadTimeStep(t, outputerr);
// (Call whose trailing arguments these are is on a missing line — it
// writes into softmax_errors_t / outputerr via the softmax layer.)
507 softmax_errors_t.get(), outputerr);
// Propagate state error backwards through the NEXT timestep's forget
// gate GF1: stateerr(t) gets scaled by gf1(t+1).
515 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
516 for (
int i = 0; i < ns_; ++i) {
517 curr_stateerr[i] *= next_node_gf1[i];
// 2-D: which_fg_ records which forget gate was used in the forward pass.
// Zero the error where the x-direction gate (value 1) was not the one
// used, and add the contribution coming up from the row below (value 2).
520 if (
Is2D() && t + 1 < width) {
521 for (
int i = 0; i < ns_; ++i) {
522 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
525 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
526 const double* right_stateerr = stateerr[mod_t];
527 for (
int i = 0; i < ns_; ++i) {
528 if (which_fg_[down_pos][i] == 2) {
529 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
// Keep the state error bounded to avoid gradient explosion.
537 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
// Debug print of the last ~10 timesteps' errors (guard lines missing).
539 if (t + 10 > width) {
541 for (
int i = 0; i < ns_; ++i)
542 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
543 curr_sourceerr[
ni_ + nf_ + i]);
// Cell-input (CI) error: g'(CI) * GI * stateerr, then store transposed.
551 node_values_[
CI].FuncMultiply3<GPrime>(t, node_values_[
GI], t,
552 curr_stateerr, gate_errors[
CI]);
555 gate_errors_t[
CI].get()->WriteStrided(t, gate_errors[CI]);
// Input-gate (GI) error: f'(GI) * CI * stateerr.
559 node_values_[
GI].FuncMultiply3<FPrime>(t, node_values_[
CI], t,
560 curr_stateerr, gate_errors[
GI]);
563 gate_errors_t[
GI].get()->WriteStrided(t, gate_errors[GI]);
// Forget-gate (GF1, x-direction) error uses the PREVIOUS state (t - 1);
// when there is no previous step (the guard is a missing line) the error
// and its source contribution are zeroed instead.
568 node_values_[
GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
572 sourceerr_temps[GF1]);
574 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[GF1][0]));
575 memset(sourceerr_temps[GF1], 0, na_ *
sizeof(*sourceerr_temps[GF1]));
577 gate_errors_t[
GF1].get()->WriteStrided(t, gate_errors[GF1]);
// Forget-gate (GFS, y-direction) error uses the state one row up
// (up_pos); zeroed analogously when no such row exists, and only stored
// at all in the 2-D case.
581 node_values_[
GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
585 sourceerr_temps[GFS]);
587 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[GFS][0]));
588 memset(sourceerr_temps[GFS], 0, na_ *
sizeof(*sourceerr_temps[GFS]));
590 if (
Is2D()) gate_errors_t[
GFS].get()->WriteStrided(t, gate_errors[GFS]);
// Output-gate (GO) error: h(state) * f'(GO) * outputerr (trailing args
// of this call are on missing lines).
594 state_.Func2Multiply3<HFunc, FPrime>(node_values_[
GO], t, outputerr,
598 gate_errors_t[
GO].get()->WriteStrided(t, gate_errors[GO]);
// Combine the five per-gate source errors; the output argument of this
// SumVectors call is on a missing line (presumably curr_sourceerr, which
// is written out next — confirm against the full file).
601 SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
602 sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
604 back_deltas->WriteTimeStep(t, curr_sourceerr);
// Remember this step's errors in the rotating buffers for use by earlier
// (2-D: upper-row) timesteps.
607 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
608 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
// End of the main backwards time loop (its `do {` opener is before or
// among the missing lines of this fragment).
610 }
while (dest_index.Decrement());
// Debug dump of per-gate errors (guard lines missing).
612 for (
int w = 0; w <
WT_COUNT; ++w) {
614 gate_errors_t[w].get()->PrintUnTransposed(10);
// Transposed copies of source and state, needed to form the outer
// products for the weight gradients (source transpose call is missing).
618 NetworkScratch::GradientStore source_t, state_t;
619 source_t.Init(na_, width, scratch);
621 state_t.Init(ns_, width, scratch);
622 state_.Transpose(state_t.get());
// Parallel weight-gradient accumulation, one thread per gate; note the
// `626 for (` fused onto the pragma line is an extraction artifact — in
// the original the for-loop is on its own line after the pragma.
624 #pragma omp parallel for num_threads(GFS) if (!Is2D()) 626 for (
int w = 0; w <
WT_COUNT; ++w) {
// GFS weights only exist in 2-D mode.
627 if (w == GFS && !
Is2D())
continue;
// Softmax weight gradients (body continues past this fragment).
630 if (softmax_ != NULL) {
void CopyVector(int n, const double *src, double *dest)
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void ClipVector(int n, T lower, T upper, T *vec)
#define SECTION_IF_OPENMP
int Size(FlexDimensions dimension) const
#define PARALLEL_IF_OPENMP(__num_threads)
void FinishBackward(const TransposedArray &errors_t)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
void VectorDotMatrix(const double *u, double *v) const
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
const char * string() const
void Transpose(TransposedArray *dest) const
void DisplayBackward(const NetworkIO &matrix)
#define END_PARALLEL_IF_OPENMP
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
void init_to_size(int size, T t)
void AccumulateVector(int n, const double *src, double *dest)