tesseract v5.3.3.20231005
input.cpp
Go to the documentation of this file.
1
2// File: input.cpp
3// Description: Input layer class for neural network implementations.
4// Author: Ray Smith
5//
6// (C) Copyright 2014, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18#include "input.h"
19
20#include <allheaders.h>
21#include "imagedata.h"
22#include "pageres.h"
23#include "scrollview.h"
24
25namespace tesseract {
26
// Max height for variable height inputs before scaling anyway.
// Passed as the max_height argument of ImageData::PreScale by
// Input::PrepareLSTMInputs.
const int kMaxInputHeight{48};
29
30Input::Input(const std::string &name, int ni, int no)
31 : Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}
32Input::Input(const std::string &name, const StaticShape &shape)
33 : Network(NT_INPUT, name, shape.height(), shape.depth()), shape_(shape), cached_x_scale_(1) {
34 if (shape.height() == 1) {
35 ni_ = shape.depth();
36 }
37}
38
39// Writes to the given file. Returns false in case of error.
40bool Input::Serialize(TFile *fp) const {
41 return Network::Serialize(fp) && shape_.Serialize(fp);
42}
43
44// Reads from the given file. Returns false in case of error.
46 return shape_.DeSerialize(fp);
47}
48
49// Returns an integer reduction factor that the network applies to the
50// time sequence. Assumes that any 2-d is already eliminated. Used for
51// scaling bounding boxes of truth data.
53 return 1;
54}
55
56// Provides the (minimum) x scale factor to the network (of interest only to
57// input units) so they can determine how to scale bounding boxes.
58void Input::CacheXScaleFactor(int factor) {
59 cached_x_scale_ = factor;
60}
61
62// Runs forward propagation of activations on the input line.
63// See Network for a detailed discussion of the arguments.
64void Input::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
65 NetworkScratch *scratch, NetworkIO *output) {
66 *output = input;
67}
68
69// Runs backward propagation of errors on the deltas line.
70// See NetworkCpp for a detailed discussion of the arguments.
71bool Input::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
72 NetworkIO *back_deltas) {
73 tprintf("Input::Backward should not be called!!\n");
74 return false;
75}
76
77// Creates and returns a Pix of appropriate size for the network from the
78// image_data. If non-null, *image_scale returns the image scale factor used.
79// Returns nullptr on error.
80/* static */
81Image Input::PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width,
82 TRand *randomizer, float *image_scale) {
83 // Note that NumInputs() is defined as input image height.
84 int target_height = network->NumInputs();
85 int width, height;
86 Image pix =
87 image_data.PreScale(target_height, kMaxInputHeight, image_scale, &width, &height, nullptr);
88 if (pix == nullptr) {
89 tprintf("Bad pix from ImageData!\n");
90 return nullptr;
91 }
92 if (width < min_width || height < min_width) {
93 tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width, height, min_width);
94 pix.destroy();
95 return nullptr;
96 }
97 return pix;
98}
99
100// Converts the given pix to a NetworkIO of height and depth appropriate to the
101// given StaticShape:
102// If depth == 3, convert to 24 bit color, otherwise normalized grey.
103// Scale to target height, if the shape's height is > 1, or its depth if the
104// height == 1. If height == 0 then no scaling.
105// NOTE: It isn't safe for multiple threads to call this on the same pix.
106/* static */
107void Input::PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer,
108 NetworkIO *input) {
109 bool color = shape.depth() == 3;
110 Image var_pix = pix;
111 int depth = pixGetDepth(var_pix);
112 Image normed_pix = nullptr;
113 // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
114 // colormap, so we just have to deal with depth conversion here.
115 if (color) {
116 // Force RGB.
117 if (depth == 32) {
118 normed_pix = var_pix.clone();
119 } else {
120 normed_pix = pixConvertTo32(var_pix);
121 }
122 } else {
123 // Convert non-8-bit images to 8 bit.
124 if (depth == 8) {
125 normed_pix = var_pix.clone();
126 } else {
127 normed_pix = pixConvertTo8(var_pix, false);
128 }
129 }
130 int height = pixGetHeight(normed_pix);
131 int target_height = shape.height();
132 if (target_height == 1) {
133 target_height = shape.depth();
134 }
135 if (target_height != 0 && target_height != height) {
136 // Get the scaled image.
137 float im_factor = static_cast<float>(target_height) / height;
138 Image scaled_pix = pixScale(normed_pix, im_factor, im_factor);
139 normed_pix.destroy();
140 normed_pix = scaled_pix;
141 }
142 input->FromPix(shape, normed_pix, randomizer);
143 normed_pix.destroy();
144}
145
146} // namespace tesseract.
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
const int kMaxInputHeight
Definition: input.cpp:28
@ NT_INPUT
Definition: network.h:43
Image clone() const
Definition: image.cpp:24
void destroy()
Definition: image.cpp:32
Image PreScale(int target_height, int max_height, float *scale_factor, int *scaled_width, int *scaled_height, std::vector< TBOX > *boxes) const
Definition: imagedata.cpp:215
TESS_API Input(const std::string &name, int ni, int no)
Definition: input.cpp:30
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: input.cpp:64
int XScaleFactor() const override
Definition: input.cpp:52
static Image PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:81
void CacheXScaleFactor(int factor) override
Definition: input.cpp:58
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: input.cpp:71
static void PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:107
bool Serialize(TFile *fp) const override
Definition: input.cpp:40
bool DeSerialize(TFile *fp) override
Definition: input.cpp:45
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:158
int NumInputs() const
Definition: network.h:122
void FromPix(const StaticShape &shape, const Image pix, TRand *randomizer)
Definition: networkio.cpp:163
bool Serialize(TFile *fp) const
Definition: static_shape.h:91
bool DeSerialize(TFile *fp)
Definition: static_shape.h:83