19#ifndef TESSERACT_LSTM_PLUMBING_H_
20#define TESSERACT_LSTM_PLUMBING_H_
31 explicit Plumbing(
const std::string &name);
33 for (
auto data : stack_) {
40 return stack_[0]->InputShape();
42 std::string
spec()
const override {
43 return "Sub-classes of Plumbing must implement spec()!";
58 void SetNetworkFlags(uint32_t flags)
override;
65 int InitWeights(
float range,
TRand *randomizer)
override;
68 int RemapOutputs(
int old_no,
const std::vector<int> &code_map)
override;
71 void ConvertToInt()
override;
76 void SetRandomizer(
TRand *randomizer)
override;
79 virtual void AddToStack(
Network *network);
84 bool SetupNeedsBackprop(
bool needs_backprop)
override;
92 int XScaleFactor()
const override;
96 void CacheXScaleFactor(
int factor)
override;
99 void DebugWeights()
override;
102 const std::vector<Network *> &
stack()
const {
106 void EnumerateLayers(
const std::string *prefix, std::vector<std::string> &layers)
const;
108 Network *GetLayer(
const char *
id)
const;
111 const float *lr_ptr = LayerLearningRatePtr(
id);
117 float *lr_ptr = LayerLearningRatePtr(
id);
124 float *lr_ptr = LayerLearningRatePtr(
id);
126 *lr_ptr = learning_rate;
130 float *LayerLearningRatePtr(
const char *
id);
139 void Update(
float learning_rate,
float momentum,
float adam_beta,
int num_samples)
override;
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
bool Serialize(FILE *fp, const std::vector< T > &data)
std::string spec() const override
const std::vector< Network * > & stack() const
float LayerLearningRate(const char *id)
StaticShape InputShape() const override
std::vector< Network * > stack_
void SetLayerLearningRate(const char *id, float learning_rate)
void ScaleLayerLearningRate(const char *id, double factor)
std::vector< float > learning_rates_
bool IsPlumbingType() const override