tesseract v5.3.3.20231005
networkbuilder.cpp
Go to the documentation of this file.
1
// File: networkbuilder.cpp
// Description: Class to parse the network description language and
// build a corresponding network.
// Author: Ray Smith
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18
19#include "networkbuilder.h"
20
21#include "convolve.h"
22#include "fullyconnected.h"
23#include "input.h"
24#include "lstm.h"
25#include "maxpool.h"
26#include "network.h"
27#include "parallel.h"
28#include "reconfig.h"
29#include "reversed.h"
30#include "series.h"
31#include "unicharset.h"
32
33namespace tesseract {
34
35// Builds a network with a network_spec in the network description
36// language, to recognize a character set of num_outputs size.
37// If append_index is non-negative, then *network must be non-null and the
38// given network_spec will be appended to *network AFTER append_index, with
39// the top of the input *network discarded.
40// Note that network_spec is call by value to allow a non-const char* pointer
41// into the string for BuildFromString.
42// net_flags control network behavior according to the NetworkFlags enum.
43// The resulting network is returned via **network.
44// Returns false if something failed.
45bool NetworkBuilder::InitNetwork(int num_outputs, const char *network_spec, int append_index,
46 int net_flags, float weight_range, TRand *randomizer,
47 Network **network) {
48 NetworkBuilder builder(num_outputs);
49 Series *bottom_series = nullptr;
50 StaticShape input_shape;
51 if (append_index >= 0) {
52 // Split the current network after the given append_index.
53 ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES);
54 auto *series = static_cast<Series *>(*network);
55 Series *top_series = nullptr;
56 series->SplitAt(append_index, &bottom_series, &top_series);
57 if (bottom_series == nullptr || top_series == nullptr) {
58 tprintf("Yikes! Splitting current network failed!!\n");
59 return false;
60 }
61 input_shape = bottom_series->OutputShape(input_shape);
62 delete top_series;
63 }
64 *network = builder.BuildFromString(input_shape, &network_spec);
65 if (*network == nullptr) {
66 return false;
67 }
68 (*network)->SetNetworkFlags(net_flags);
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(false);
71 if (bottom_series != nullptr) {
72 bottom_series->AppendSeries(*network);
73 *network = bottom_series;
74 }
75 (*network)->CacheXScaleFactor((*network)->XScaleFactor());
76 return true;
77}
78
// Helper advances *str past any run of leading spaces, tabs and newlines.
// Other characters (including '\r') stop the scan.
static void SkipWhitespace(const char **str) {
  const char *cursor = *str;
  for (;;) {
    const char ch = *cursor;
    if (ch != ' ' && ch != '\t' && ch != '\n') {
      break;
    }
    ++cursor;
  }
  *str = cursor;
}
85
86// Parses the given string and returns a network according to the network
87// description language in networkbuilder.h
88Network *NetworkBuilder::BuildFromString(const StaticShape &input_shape, const char **str) {
89 SkipWhitespace(str);
90 char code_ch = **str;
91 if (code_ch == '[') {
92 return ParseSeries(input_shape, nullptr, str);
93 }
94 if (input_shape.depth() == 0) {
95 // There must be an input at this point.
96 return ParseInput(str);
97 }
98 switch (code_ch) {
99 case '(':
100 return ParseParallel(input_shape, str);
101 case 'R':
102 return ParseR(input_shape, str);
103 case 'S':
104 return ParseS(input_shape, str);
105 case 'C':
106 return ParseC(input_shape, str);
107 case 'M':
108 return ParseM(input_shape, str);
109 case 'L':
110 return ParseLSTM(input_shape, str);
111 case 'F':
112 return ParseFullyConnected(input_shape, str);
113 case 'O':
114 return ParseOutput(input_shape, str);
115 default:
116 tprintf("Invalid network spec:%s\n", *str);
117 return nullptr;
118 }
119 return nullptr;
120}
121
122// Parses an input specification and returns the result, which may include a
123// series.
124Network *NetworkBuilder::ParseInput(const char **str) {
125 // There must be an input at this point.
126 int length = 0;
127 int batch, height, width, depth;
128 int num_converted = sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
129 StaticShape shape;
130 shape.SetShape(batch, height, width, depth);
131 // num_converted may or may not include the length.
132 if (num_converted != 4 && num_converted != 5) {
133 tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
134 return nullptr;
135 }
136 *str += length;
137 auto *input = new Input("Input", shape);
138 // We want to allow [<input>rest of net... or <input>[rest of net... so we
139 // have to check explicitly for '[' here.
140 SkipWhitespace(str);
141 if (**str == '[') {
142 return ParseSeries(shape, input, str);
143 }
144 return input;
145}
146
147// Parses a sequential series of networks, defined by [<net><net>...].
148Network *NetworkBuilder::ParseSeries(const StaticShape &input_shape, Input *input_layer,
149 const char **str) {
150 StaticShape shape = input_shape;
151 auto *series = new Series("Series");
152 ++*str;
153 if (input_layer != nullptr) {
154 series->AddToStack(input_layer);
155 shape = input_layer->OutputShape(shape);
156 }
157 Network *network = nullptr;
158 while (**str != '\0' && **str != ']' && (network = BuildFromString(shape, str)) != nullptr) {
159 shape = network->OutputShape(shape);
160 series->AddToStack(network);
161 }
162 if (**str != ']') {
163 tprintf("Missing ] at end of [Series]!\n");
164 delete series;
165 return nullptr;
166 }
167 ++*str;
168 return series;
169}
170
171// Parses a parallel set of networks, defined by (<net><net>...).
172Network *NetworkBuilder::ParseParallel(const StaticShape &input_shape, const char **str) {
173 auto *parallel = new Parallel("Parallel", NT_PARALLEL);
174 ++*str;
175 Network *network = nullptr;
176 while (**str != '\0' && **str != ')' &&
177 (network = BuildFromString(input_shape, str)) != nullptr) {
178 parallel->AddToStack(network);
179 }
180 if (**str != ')') {
181 tprintf("Missing ) at end of (Parallel)!\n");
182 delete parallel;
183 return nullptr;
184 }
185 ++*str;
186 return parallel;
187}
188
189// Parses a network that begins with 'R'.
190Network *NetworkBuilder::ParseR(const StaticShape &input_shape, const char **str) {
191 char dir = (*str)[1];
192 if (dir == 'x' || dir == 'y') {
193 std::string name = "Reverse";
194 name += dir;
195 *str += 2;
196 Network *network = BuildFromString(input_shape, str);
197 if (network == nullptr) {
198 return nullptr;
199 }
200 auto *rev = new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
201 rev->SetNetwork(network);
202 return rev;
203 }
204 char *end;
205 int replicas = strtol(*str + 1, &end, 10);
206 *str = end;
207 if (replicas <= 0) {
208 tprintf("Invalid R spec!:%s\n", end);
209 return nullptr;
210 }
211 auto *parallel = new Parallel("Replicated", NT_REPLICATED);
212 const char *str_copy = *str;
213 for (int i = 0; i < replicas; ++i) {
214 str_copy = *str;
215 Network *network = BuildFromString(input_shape, &str_copy);
216 if (network == nullptr) {
217 tprintf("Invalid replicated network!\n");
218 delete parallel;
219 return nullptr;
220 }
221 parallel->AddToStack(network);
222 }
223 *str = str_copy;
224 return parallel;
225}
226
227// Parses a network that begins with 'S'.
228Network *NetworkBuilder::ParseS(const StaticShape &input_shape, const char **str) {
229 char *end;
230 int y = strtol(*str + 1, &end, 10);
231 *str = end;
232 if (**str == ',') {
233 int x = strtol(*str + 1, &end, 10);
234 *str = end;
235 if (y <= 0 || x <= 0) {
236 tprintf("Invalid S spec!:%s\n", *str);
237 return nullptr;
238 }
239 return new Reconfig("Reconfig", input_shape.depth(), x, y);
240 } else if (**str == '(') {
241 // TODO(rays) Add Generic reshape.
242 tprintf("Generic reshape not yet implemented!!\n");
243 return nullptr;
244 }
245 tprintf("Invalid S spec!:%s\n", *str);
246 return nullptr;
247}
248
249// Helper returns the fully-connected type for the character code.
250static NetworkType NonLinearity(char func) {
251 switch (func) {
252 case 's':
253 return NT_LOGISTIC;
254 case 't':
255 return NT_TANH;
256 case 'r':
257 return NT_RELU;
258 case 'l':
259 return NT_LINEAR;
260 case 'm':
261 return NT_SOFTMAX;
262 case 'p':
263 return NT_POSCLIP;
264 case 'n':
265 return NT_SYMCLIP;
266 default:
267 return NT_NONE;
268 }
269}
270
271// Parses a network that begins with 'C'.
272Network *NetworkBuilder::ParseC(const StaticShape &input_shape, const char **str) {
273 NetworkType type = NonLinearity((*str)[1]);
274 if (type == NT_NONE) {
275 tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
276 return nullptr;
277 }
278 int y = 0, x = 0, d = 0;
279 char *end;
280 if ((y = strtol(*str + 2, &end, 10)) <= 0 || *end != ',' ||
281 (x = strtol(end + 1, &end, 10)) <= 0 || *end != ',' || (d = strtol(end + 1, &end, 10)) <= 0) {
282 tprintf("Invalid C spec!:%s\n", end);
283 return nullptr;
284 }
285 *str = end;
286 if (x == 1 && y == 1) {
287 // No actual convolution. Just a FullyConnected on the current depth, to
288 // be slid over all batch,y,x.
289 return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
290 }
291 auto *series = new Series("ConvSeries");
292 auto *convolve = new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
293 series->AddToStack(convolve);
294 StaticShape fc_input = convolve->OutputShape(input_shape);
295 series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
296 return series;
297}
298
299// Parses a network that begins with 'M'.
300Network *NetworkBuilder::ParseM(const StaticShape &input_shape, const char **str) {
301 int y = 0, x = 0;
302 char *end;
303 if ((*str)[1] != 'p' || (y = strtol(*str + 2, &end, 10)) <= 0 || *end != ',' ||
304 (x = strtol(end + 1, &end, 10)) <= 0) {
305 tprintf("Invalid Mp spec!:%s\n", *str);
306 return nullptr;
307 }
308 *str = end;
309 return new Maxpool("Maxpool", input_shape.depth(), x, y);
310}
311
312// Parses an LSTM network, either individual, bi- or quad-directional.
313Network *NetworkBuilder::ParseLSTM(const StaticShape &input_shape, const char **str) {
314 bool two_d = false;
316 const char *spec_start = *str;
317 int chars_consumed = 1;
318 int num_outputs = 0;
319 char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
320 if (key == 'S') {
322 num_outputs = num_softmax_outputs_;
323 ++chars_consumed;
324 } else if (key == 'E') {
326 num_outputs = num_softmax_outputs_;
327 ++chars_consumed;
328 } else if (key == '2' &&
329 (((*str)[2] == 'x' && (*str)[3] == 'y') || ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
330 chars_consumed = 4;
331 dim = (*str)[3];
332 two_d = true;
333 } else if (key == 'f' || key == 'r' || key == 'b') {
334 dir = key;
335 dim = (*str)[2];
336 if (dim != 'x' && dim != 'y') {
337 tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
338 return nullptr;
339 }
340 chars_consumed = 3;
341 if ((*str)[chars_consumed] == 's') {
342 ++chars_consumed;
344 }
345 } else {
346 tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
347 return nullptr;
348 }
349 char *end;
350 int num_states = strtol(*str + chars_consumed, &end, 10);
351 if (num_states <= 0) {
352 tprintf("Invalid number of states in L Spec!:%s\n", *str);
353 return nullptr;
354 }
355 *str = end;
356 Network *lstm = nullptr;
357 if (two_d) {
358 lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
359 } else {
360 if (num_outputs == 0) {
361 num_outputs = num_states;
362 }
363 std::string name(spec_start, *str - spec_start);
364 lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type);
365 if (dir != 'f') {
366 auto *rev = new Reversed("RevLSTM", NT_XREVERSED);
367 rev->SetNetwork(lstm);
368 lstm = rev;
369 }
370 if (dir == 'b') {
371 name += "LTR";
372 auto *parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
373 parallel->AddToStack(
374 new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type));
375 parallel->AddToStack(lstm);
376 lstm = parallel;
377 }
378 }
379 if (dim == 'y') {
380 auto *rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
381 rev->SetNetwork(lstm);
382 lstm = rev;
383 }
384 return lstm;
385}
386
387// Builds a set of 4 lstms with x and y reversal, running in true parallel.
388Network *NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
389 auto *parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
390 parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM));
391 auto *rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
392 rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states, true, NT_LSTM));
393 parallel->AddToStack(rev);
394 rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
395 rev->SetNetwork(new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
396 auto *rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
397 rev2->SetNetwork(rev);
398 parallel->AddToStack(rev2);
399 rev = new Reversed("L2DXRevY", NT_YREVERSED);
400 rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM));
401 parallel->AddToStack(rev);
402 return parallel;
403}
404
405// Helper builds a truly (0-d) fully connected layer of the given type.
406static Network *BuildFullyConnected(const StaticShape &input_shape, NetworkType type,
407 const std::string &name, int depth) {
408 if (input_shape.height() == 0 || input_shape.width() == 0) {
409 tprintf("Fully connected requires positive height and width, had %d,%d\n", input_shape.height(),
410 input_shape.width());
411 return nullptr;
412 }
413 int input_size = input_shape.height() * input_shape.width();
414 int input_depth = input_size * input_shape.depth();
415 Network *fc = new FullyConnected(name, input_depth, depth, type);
416 if (input_size > 1) {
417 auto *series = new Series("FCSeries");
418 series->AddToStack(
419 new Reconfig("FCReconfig", input_shape.depth(), input_shape.width(), input_shape.height()));
420 series->AddToStack(fc);
421 fc = series;
422 }
423 return fc;
424}
425
426// Parses a Fully connected network.
427Network *NetworkBuilder::ParseFullyConnected(const StaticShape &input_shape, const char **str) {
428 const char *spec_start = *str;
429 NetworkType type = NonLinearity((*str)[1]);
430 if (type == NT_NONE) {
431 tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
432 return nullptr;
433 }
434 char *end;
435 int depth = strtol(*str + 2, &end, 10);
436 if (depth <= 0) {
437 tprintf("Invalid F spec!:%s\n", *str);
438 return nullptr;
439 }
440 *str = end;
441 std::string name(spec_start, *str - spec_start);
442 return BuildFullyConnected(input_shape, type, name, depth);
443}
444
445// Parses an Output spec.
446Network *NetworkBuilder::ParseOutput(const StaticShape &input_shape, const char **str) {
447 char dims_ch = (*str)[1];
448 if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
449 tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
450 return nullptr;
451 }
452 char type_ch = (*str)[2];
453 if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
454 tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
455 return nullptr;
456 }
457 char *end;
458 int depth = strtol(*str + 3, &end, 10);
459 if (depth != num_softmax_outputs_) {
460 tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
461 num_softmax_outputs_);
462 depth = num_softmax_outputs_;
463 }
464 *str = end;
466 if (type_ch == 'l') {
468 } else if (type_ch == 's') {
470 }
471 if (dims_ch == '0') {
472 // Same as standard fully connected.
473 return BuildFullyConnected(input_shape, type, "Output", depth);
474 } else if (dims_ch == '2') {
475 // We don't care if x and/or y are variable.
476 return new FullyConnected("Output2d", input_shape.depth(), depth, type);
477 }
478 // For 1-d y has to be fixed, and if not 1, moved to depth.
479 if (input_shape.height() == 0) {
480 tprintf("Fully connected requires fixed height!\n");
481 return nullptr;
482 }
483 int input_size = input_shape.height();
484 int input_depth = input_size * input_shape.depth();
485 Network *fc = new FullyConnected("Output", input_depth, depth, type);
486 if (input_size > 1) {
487 auto *series = new Series("FCSeries");
488 series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1, input_shape.height()));
489 series->AddToStack(fc);
490 fc = series;
491 }
492 return fc;
493}
494
495} // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:54
const double y
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
NetworkType
Definition: network.h:41
@ NT_LINEAR
Definition: network.h:65
@ NT_RELU
Definition: network.h:64
@ NT_XREVERSED
Definition: network.h:54
@ NT_LSTM
Definition: network.h:58
@ NT_SOFTMAX
Definition: network.h:66
@ NT_NONE
Definition: network.h:42
@ NT_LOGISTIC
Definition: network.h:60
@ NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:74
@ NT_PARALLEL
Definition: network.h:47
@ NT_SYMCLIP
Definition: network.h:62
@ NT_PAR_2D_LSTM
Definition: network.h:51
@ NT_LSTM_SUMMARY
Definition: network.h:59
@ NT_YREVERSED
Definition: network.h:55
@ NT_POSCLIP
Definition: network.h:61
@ NT_LSTM_SOFTMAX
Definition: network.h:73
@ NT_XYTRANSPOSE
Definition: network.h:56
@ NT_SERIES
Definition: network.h:52
@ NT_SOFTMAX_NO_CTC
Definition: network.h:67
@ NT_TANH
Definition: network.h:63
@ NT_PAR_RL_LSTM
Definition: network.h:49
@ NT_REPLICATED
Definition: network.h:48
type
Definition: upload.py:458
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:131
NetworkType type() const
Definition: network.h:110
TESS_API void AppendSeries(Network *src)
Definition: series.cpp:192
TESS_API void SplitAt(unsigned last_start, Series **start, Series **end)
Definition: series.cpp:163
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:34
void CacheXScaleFactor(int factor) override
Definition: series.cpp:100
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:71
static bool InitNetwork(int num_outputs, const char *network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Network * BuildFromString(const StaticShape &input_shape, const char **str)