19#ifndef TESSERACT_LSTM_TFNETWORK_H_
20#define TESSERACT_LSTM_TFNETWORK_H_
22#ifdef INCLUDE_TENSORFLOW
29# include "tensorflow/core/framework/graph.pb.h"
30# include "tensorflow/core/public/session.h"
35class TFNetwork :
public Network {
37 explicit TFNetwork(
const char *name);
38 virtual ~TFNetwork() =
default;
41 StaticShape InputShape()
const override {
46 StaticShape OutputShape(
const StaticShape &input_shape)
const override {
50 std::string spec()
const override {
56 int InitFromProtoStr(
const std::string &proto_str);
59 int num_classes()
const {
60 return output_shape_.depth();
72 void Forward(
bool debug,
const NetworkIO &input,
const TransposedArray *input_transpose,
73 NetworkScratch *scratch, NetworkIO *
output)
override;
78 bool Backward(
bool debug,
const NetworkIO &fwd_deltas, NetworkScratch *scratch,
79 NetworkIO *back_deltas)
override {
80 tprintf(
"Must override Network::Backward for type %d\n", type_);
84 void DebugWeights()
override {
85 tprintf(
"Must override Network::DebugWeights for type %d\n", type_);
93 StaticShape input_shape_;
95 StaticShape output_shape_;
97 std::unique_ptr<tensorflow::Session> session_;
99 TFNetworkModel model_proto_;
void tprintf(const char *format,...)
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
bool Serialize(FILE *fp, const std::vector< T > &data)