tesseract v5.3.3.20231005
networkscratch.h
Go to the documentation of this file.
1
2// File: networkscratch.h
3// Description: Scratch space for Network layers that hides distinction
4// between float/int implementations.
5// Author: Ray Smith
6//
7// (C) Copyright 2014, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
20#define TESSERACT_LSTM_NETWORKSCRATCH_H_
21
22#include <mutex>
23#include "matrix.h"
24#include "networkio.h"
25
26namespace tesseract {
27
28// Generic scratch space for network layers. Provides NetworkIO that can store
29// a complete set (over time) of intermediates, and vector<float>
30// scratch space that auto-frees after use. The aim here is to provide a set
31// of temporary buffers to network layers that can be reused between layers
32// and don't have to be reallocated on each call.
34public:
35 NetworkScratch() : int_mode_(false) {}
36 ~NetworkScratch() = default;
37
38 // Sets the network representation. If the representation is integer, then
39 // default (integer) NetworkIOs are separated from the always-float variety.
40 // This saves memory by having separate int-specific and float-specific
41 // stacks. If the network representation is float, then all NetworkIOs go
42 // to the float stack.
43 void set_int_mode(bool int_mode) {
44 int_mode_ = int_mode;
45 }
46
47 // Class that acts like a NetworkIO (by having an implicit cast operator),
48 // yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
49 // and knows how to unstack the borrowed pointers on destruction.
50 class IO {
51 public:
52 // The NetworkIO should be sized after construction.
53 IO(const NetworkIO &src, NetworkScratch *scratch)
54 : int_mode_(scratch->int_mode_ && src.int_mode()), scratch_space_(scratch) {
55 network_io_ =
56 int_mode_ ? scratch_space_->int_stack_.Borrow() : scratch_space_->float_stack_.Borrow();
57 }
58 // Default constructor for arrays. Use one of the Resize functions
59 // below to initialize and size.
60 IO() : int_mode_(false), network_io_(nullptr), scratch_space_(nullptr) {}
61
62 ~IO() {
63 if (scratch_space_ == nullptr) {
64 ASSERT_HOST(network_io_ == nullptr);
65 } else if (int_mode_) {
66 scratch_space_->int_stack_.Return(network_io_);
67 } else {
68 scratch_space_->float_stack_.Return(network_io_);
69 }
70 }
71 // Resizes the array (and stride), avoiding realloc if possible, to the
72 // size from various size specs:
73 // Same time size, given number of features.
74 void Resize(const NetworkIO &src, int num_features, NetworkScratch *scratch) {
75 if (scratch_space_ == nullptr) {
76 int_mode_ = scratch->int_mode_ && src.int_mode();
77 scratch_space_ = scratch;
78 network_io_ =
79 int_mode_ ? scratch_space_->int_stack_.Borrow() : scratch_space_->float_stack_.Borrow();
80 }
81 network_io_->Resize(src, num_features);
82 }
83 // Resizes to a specific size as a temp buffer. No batches, no y-dim.
84 void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch) {
85 if (scratch_space_ == nullptr) {
86 int_mode_ = scratch->int_mode_ && int_mode;
87 scratch_space_ = scratch;
88 network_io_ =
89 int_mode_ ? scratch_space_->int_stack_.Borrow() : scratch_space_->float_stack_.Borrow();
90 }
91 network_io_->Resize2d(int_mode, width, num_features);
92 }
93 // Resize forcing a float representation with the width of src and the given
94 // number of features.
95 void ResizeFloat(const NetworkIO &src, int num_features, NetworkScratch *scratch) {
96 if (scratch_space_ == nullptr) {
97 int_mode_ = false;
98 scratch_space_ = scratch;
99 network_io_ = scratch_space_->float_stack_.Borrow();
100 }
101 network_io_->ResizeFloat(src, num_features);
102 }
103
104 // Returns a ref to a NetworkIO that enables *this to be treated as if
105 // it were just a NetworkIO*.
107 return *network_io_;
108 }
110 return network_io_;
111 }
112 operator NetworkIO *() {
113 return network_io_;
114 }
115
116 private:
117 // True if this is from the always-float stack, otherwise the default stack.
118 bool int_mode_;
119 // The NetworkIO that we have borrowed from the scratch_space_.
120 NetworkIO *network_io_;
121 // The source scratch_space_. Borrowed pointer, used to free the
122 // NetworkIO. Don't delete!
123 NetworkScratch *scratch_space_;
124 }; // class IO.
125
126 // Class that acts like a fixed array of float, yet actually uses space
127 // from a vector<float> in the source NetworkScratch, and knows how
128 // to unstack the borrowed vector on destruction.
129 class FloatVec {
130 public:
131 // The array will have size elements in it, uninitialized.
132 FloatVec(int size, NetworkScratch *scratch) : vec_(nullptr), scratch_space_(scratch) {
133 Init(size, scratch);
134 }
135 // Default constructor is for arrays. Use Init to setup.
136 FloatVec() : vec_(nullptr), data_(nullptr), scratch_space_(nullptr) {}
138 if (scratch_space_ != nullptr) {
139 scratch_space_->vec_stack_.Return(vec_);
140 }
141 }
142
143 void Init(int /*size*/, int reserve, NetworkScratch *scratch) {
144 if (scratch_space_ != nullptr && vec_ != nullptr) {
145 scratch_space_->vec_stack_.Return(vec_);
146 }
147 scratch_space_ = scratch;
148 vec_ = scratch_space_->vec_stack_.Borrow();
149 // TODO: optimize.
150 vec_->resize(reserve);
151 data_ = &(*vec_)[0];
152 }
153
154 void Init(int size, NetworkScratch *scratch) {
155 Init(size, size, scratch);
156 }
157
158 // Use the cast operator instead of operator[] so the FloatVec can be used
159 // as a TFloat* argument to a function call.
160 operator TFloat *() const {
161 return data_;
162 }
164 return data_;
165 }
166
167 private:
168 // Vector borrowed from the scratch space. Use Return to free it.
169 std::vector<TFloat> *vec_;
170 // Short-cut pointer to the underlying array.
171 TFloat *data_;
172 // The source scratch_space_. Borrowed pointer, used to free the
173 // vector. Don't delete!
174 NetworkScratch *scratch_space_;
175 }; // class FloatVec
176
177 // Class that acts like a 2-D array of TFloat, yet actually uses space
178 // from the source NetworkScratch, and knows how to unstack the borrowed
179 // array on destruction.
181 public:
182 // Default constructor is for arrays. Use Init to setup.
183 GradientStore() : array_(nullptr), scratch_space_(nullptr) {}
185 if (scratch_space_ != nullptr) {
186 scratch_space_->array_stack_.Return(array_);
187 }
188 }
189
190 void Init(int size1, int size2, NetworkScratch *scratch) {
191 if (scratch_space_ != nullptr && array_ != nullptr) {
192 scratch_space_->array_stack_.Return(array_);
193 }
194 scratch_space_ = scratch;
195 array_ = scratch_space_->array_stack_.Borrow();
196 array_->Resize(size1, size2, 0.0);
197 }
198
199 // Accessors to get to the underlying TransposedArray.
201 return array_;
202 }
203 const TransposedArray &operator*() const {
204 return *array_;
205 }
206
207 private:
208 // Array borrowed from the scratch space. Use Return to free it.
209 TransposedArray *array_;
210 // The source scratch_space_. Borrowed pointer, used to free the
211 // vector. Don't delete!
212 NetworkScratch *scratch_space_;
213 }; // class GradientStore
214
215 // Class that does the work of holding a stack of objects, a stack pointer
216 // and a vector of in-use flags, so objects can be returned out of order.
217 // It is safe to attempt to Borrow/Return in multiple threads.
template <typename T>
class Stack {
public:
  Stack() = default;

  // The stack owns every object it has ever created, lent out or not.
  ~Stack() {
    for (T *item : stack_) {
      delete item;
    }
  }

  // Not copyable: the object pointers are owned, and the mutex member cannot
  // be copied anyway.
  Stack(const Stack &) = delete;
  Stack &operator=(const Stack &) = delete;

  // Lends out the next free item, creating one if none available, sets
  // the used flag and increments the stack top.
  // Returns a pointer that stays owned by the stack; hand it back with Return.
  T *Borrow() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (stack_top_ == stack_.size()) {
      stack_.push_back(new T);
      flags_.push_back(false);
    }
    flags_[stack_top_] = true;
    return stack_[stack_top_++];
  }

  // Takes back the given item, and marks it free. Item does not have to be
  // the most recently lent out, but free slots don't get re-used until the
  // blocking item is returned. The assumption is that there will only be
  // small, temporary variations from true stack use. (Determined by the order
  // of destructors within a local scope.)
  // An item that is not found below the stack top is ignored.
  void Return(T *item) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Linear search from the top down will do; the item is expected near the top.
    for (unsigned index = stack_top_; index > 0; --index) {
      if (stack_[index - 1] == item) {
        flags_[index - 1] = false;
        break;
      }
    }
    // Pop every free slot that is now exposed at the top of the stack.
    while (stack_top_ > 0 && !flags_[stack_top_ - 1]) {
      --stack_top_;
    }
  }

private:
  // All objects ever created, in creation order. Owned.
  std::vector<T *> stack_;
  // flags_[i] is true while stack_[i] is lent out.
  std::vector<bool> flags_;
  // One past the highest slot that is currently lent out.
  unsigned stack_top_ = 0;
  // Serializes Borrow and Return.
  std::mutex mutex_;
}; // class Stack.
265
266private:
267 // If true, the network weights are int8_t, if false, float.
268 bool int_mode_;
269 // Stacks of NetworkIO and vector<float>. Once allocated, they are not
270 // deleted until the NetworkScratch is deleted.
271 Stack<NetworkIO> int_stack_;
272 Stack<NetworkIO> float_stack_;
273 Stack<std::vector<TFloat>> vec_stack_;
274 Stack<TransposedArray> array_stack_;
275};
276
277} // namespace tesseract.
278
279#endif // TESSERACT_LSTM_NETWORKSCRATCH_H_
#define ASSERT_HOST(x)
Definition: errcode.h:54
double TFloat
Definition: tesstypes.h:39
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:110
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:44
bool int_mode() const
Definition: networkio.h:122
void ResizeFloat(const NetworkIO &src, int num_features)
Definition: networkio.h:51
void Resize2d(bool int_mode, int width, int num_features)
Definition: networkio.cpp:35
void set_int_mode(bool int_mode)
void Resize(const NetworkIO &src, int num_features, NetworkScratch *scratch)
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
void ResizeFloat(const NetworkIO &src, int num_features, NetworkScratch *scratch)
IO(const NetworkIO &src, NetworkScratch *scratch)
void Init(int size, NetworkScratch *scratch)
FloatVec(int size, NetworkScratch *scratch)
void Init(int, int reserve, NetworkScratch *scratch)
const TransposedArray & operator*() const
void Init(int size1, int size2, NetworkScratch *scratch)