18#ifndef TESSERACT_LSTM_NETWORK_H_
19#define TESSERACT_LSTM_NETWORK_H_
117 return needs_to_backprop_;
140 const std::string &
name()
const {
143 virtual std::string
spec()
const {
147 return (network_flags_ & flag) != 0;
171 virtual void SetNetworkFlags(uint32_t flags);
178 virtual int InitWeights(
float range,
TRand *randomizer);
191 [[maybe_unused]]
const std::vector<int> &code_map) {
201 virtual void SetRandomizer(
TRand *randomizer);
206 virtual bool SetupNeedsBackprop(
bool needs_backprop);
235 virtual void Update([[maybe_unused]]
float learning_rate,
236 [[maybe_unused]]
float momentum,
237 [[maybe_unused]]
float adam_beta,
238 [[maybe_unused]]
int num_samples) {}
243 [[maybe_unused]]
TFloat *same,
244 [[maybe_unused]]
TFloat *changed)
const {}
283 void DisplayForward(
const NetworkIO &matrix);
285 void DisplayBackward(
const NetworkIO &matrix);
288 static void ClearWindow(
bool tess_coords,
const char *window_name,
int width,
bool Serialize(FILE *fp, const std::vector< T > &data)
@ NT_LSTM_SOFTMAX_ENCODED
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
virtual int XScaleFactor() const
const std::string & name() const
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
virtual bool DeSerialize(TFile *fp)=0
virtual bool IsPlumbingType() const
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool needs_to_backprop() const
ScrollView * forward_win_
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
virtual void CacheXScaleFactor(int factor)
ScrollView * backward_win_
virtual void DebugWeights()=0
virtual StaticShape OutputShape(const StaticShape &input_shape) const
bool TestFlag(NetworkFlags flag) const
virtual std::string spec() const
virtual ~Network()=default
virtual void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const
virtual void ConvertToInt()
virtual StaticShape InputShape() const
void set_depth(int value)