18#ifndef TESSERACT_LSTM_LSTM_H_
19#define TESSERACT_LSTM_LSTM_H_
50 LSTM(
const std::string &
name,
int num_inputs,
int num_states,
int num_outputs,
58 std::string
spec()
const override {
61 spec +=
"Lfx" + std::to_string(ns_);
63 spec +=
"Lfxs" + std::to_string(ns_);
65 spec +=
"LS" + std::to_string(ns_);
67 spec +=
"LE" + std::to_string(ns_);
69 if (softmax_ !=
nullptr) {
84 int RemapOutputs(
int old_no,
const std::vector<int> &code_map)
override;
108 void Update(
float learning_rate,
float momentum,
float adam_beta,
int num_samples)
override;
125 void ResizeForward(
const NetworkIO &input);
@ NT_LSTM_SOFTMAX_ENCODED
std::string spec() const override
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
TESS_API LSTM(const std::string &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
int InitWeights(float range, TRand *randomizer) override
void DebugWeights() override
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override
bool DeSerialize(TFile *fp) override
std::string spec() const override
bool Serialize(TFile *fp) const override
void ConvertToInt() override
void SetEnableTraining(TrainingState state) override
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
StaticShape OutputShape(const StaticShape &input_shape) const override
const std::string & name() const