tesseract v5.3.3.20231005
tfnetwork.h
Go to the documentation of this file.
1
2// File: tfnetwork.h
// Description: Encapsulation of an entire TensorFlow graph as a
//              Tesseract Network.
5// Author: Ray Smith
6//
7// (C) Copyright 2016, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#ifndef TESSERACT_LSTM_TFNETWORK_H_
20#define TESSERACT_LSTM_TFNETWORK_H_
21
22#ifdef INCLUDE_TENSORFLOW
23
24# include <memory>
25# include <string>
26
27# include "network.h"
28# include "static_shape.h"
29# include "tensorflow/core/framework/graph.pb.h"
30# include "tensorflow/core/public/session.h"
31# include "tfnetwork.pb.h"
32
33namespace tesseract {
34
35class TFNetwork : public Network {
36public:
37 explicit TFNetwork(const char *name);
38 virtual ~TFNetwork() = default;
39
40 // Returns the required shape input to the network.
41 StaticShape InputShape() const override {
42 return input_shape_;
43 }
44 // Returns the shape output from the network given an input shape (which may
45 // be partially unknown ie zero).
46 StaticShape OutputShape(const StaticShape &input_shape) const override {
47 return output_shape_;
48 }
49
50 std::string spec() const override {
51 return spec_;
52 }
53
54 // Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
55 // otherwise the global step of the serialized graph.
56 int InitFromProtoStr(const std::string &proto_str);
57 // The number of classes in this network should be equal to those in the
58 // recoder_ in LSTMRecognizer.
59 int num_classes() const {
60 return output_shape_.depth();
61 }
62
63 // Writes to the given file. Returns false in case of error.
64 // Should be overridden by subclasses, but called by their Serialize.
65 bool Serialize(TFile *fp) const override;
66 // Reads from the given file. Returns false in case of error.
67 // Should be overridden by subclasses, but NOT called by their DeSerialize.
68 bool DeSerialize(TFile *fp) override;
69
70 // Runs forward propagation of activations on the input line.
71 // See Network for a detailed discussion of the arguments.
72 void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
73 NetworkScratch *scratch, NetworkIO *output) override;
74
75private:
76 // Runs backward propagation of errors on the deltas line.
77 // See Network for a detailed discussion of the arguments.
78 bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
79 NetworkIO *back_deltas) override {
80 tprintf("Must override Network::Backward for type %d\n", type_);
81 return false;
82 }
83
84 void DebugWeights() override {
85 tprintf("Must override Network::DebugWeights for type %d\n", type_);
86 }
87
88 int InitFromProto();
89
90 // The original network definition for reference.
91 std::string spec_;
92 // Input tensor parameters.
93 StaticShape input_shape_;
94 // Output tensor parameters.
95 StaticShape output_shape_;
96 // The tensor flow graph is contained in here.
97 std::unique_ptr<tensorflow::Session> session_;
98 // The serialized graph is also contained in here.
99 TFNetworkModel model_proto_;
100};
101
102} // namespace tesseract.
103
104#endif // ifdef INCLUDE_TENSORFLOW
105
#endif // TESSERACT_LSTM_TFNETWORK_H_
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:205
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:236