20# include "config_auto.h"
29#include <allheaders.h>
41#ifdef INCLUDE_TENSORFLOW
48#ifndef GRAPHICS_DISABLED
63static char const *
const kTypeNames[
NT_COUNT] = {
65 "Convolve",
"Maxpool",
66 "Parallel",
"Replicated",
67 "ParBidiLSTM",
"DepParUDLSTM",
68 "Par2dLSTM",
"Series",
69 "Reconfig",
"RTLReversed",
70 "TTBReversed",
"XYTranspose",
72 "Logistic",
"LinLogistic",
75 "Softmax",
"SoftmaxNoCTC",
76 "LSTMSoftmax",
"LSTMBinarySoftmax",
83 , needs_to_backprop_(true)
88 , forward_win_(nullptr)
89 , backward_win_(nullptr)
90 , randomizer_(nullptr) {}
94 , needs_to_backprop_(true)
100 , forward_win_(nullptr)
101 , backward_win_(nullptr)
102 , randomizer_(nullptr) {}
163 std::string type_name = kTypeNames[
type_];
187 uint32_t length =
name_.length();
200 std::string type_name;
204 for (data = 0; data <
NT_COUNT && type_name != kTypeNames[data]; ++data) {
207 tprintf(
"Invalid network layer type:%s\n", type_name.c_str());
221 int32_t network_flags;
228 type = getNetworkType(fp);
258 network =
new Input(
name.c_str(), ni, no);
264 network =
new LSTM(
name.c_str(), ni, no, no,
false,
type);
290#ifdef INCLUDE_TENSORFLOW
291 network =
new TFNetwork(
name.c_str());
293 tprintf(
"TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
329#ifndef GRAPHICS_DISABLED
343 std::string window_name =
name_ +
"-back";
352 if (*window ==
nullptr) {
353 int min_size = std::min(width, height);
369 *window =
new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords);
370 tprintf(
"Created window %s of size %d, %d\n", window_name, width, height);
379 int height = pixGetHeight(pix);
380 window->
Draw(pix, 0, 0);
void tprintf(const char *format,...)
@ NT_LSTM_SOFTMAX_ENCODED
double SignedRand(double range)
bool DeSerialize(std::string &data)
bool Serialize(const std::string &data)
const std::string & name() const
virtual bool SetupNeedsBackprop(bool needs_backprop)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
void DisplayForward(const NetworkIO &matrix)
virtual bool DeSerialize(TFile *fp)=0
void DisplayBackward(const NetworkIO &matrix)
virtual void SetEnableTraining(TrainingState state)
bool needs_to_backprop() const
ScrollView * forward_win_
static Network * CreateFromFile(TFile *fp)
virtual bool Serialize(TFile *fp) const
ScrollView * backward_win_
static int DisplayImage(Image pix, ScrollView *window)
TFloat Random(TFloat range)
virtual int InitWeights(float range, TRand *randomizer)
virtual void SetNetworkFlags(uint32_t flags)
virtual void SetRandomizer(TRand *randomizer)
void Draw(Image image, int x_pos, int y_pos)