tesseract v5.3.3.20231005
static_shape.h
Go to the documentation of this file.
1
2// File: static_shape.h
3// Description: Defines the size of the 4-d tensor input/output from a network.
4// Author: Ray Smith
5// Created: Fri Oct 14 09:07:31 PST 2016
6//
7// (C) Copyright 2016, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
20#define TESSERACT_LSTM_STATIC_SHAPE_H_
21
22#include "serialis.h" // for TFile
23#include "tprintf.h" // for tprintf
24
25namespace tesseract {
26
// Enum describing the loss function to apply during training and/or the
// decoding method to apply at runtime.
// StaticShape serializes a LossType as its int32_t value, so the numeric
// values of the existing enumerators should stay stable; append new ones at
// the end.
enum LossType {
  LT_NONE,     // Undefined.
  LT_CTC,      // Softmax with standard CTC for training/decoding.
  LT_SOFTMAX,  // Outputs sum to 1 in fixed positions.
  LT_LOGISTIC, // Logistic outputs with independent values.
};
35
36// Simple class to hold the tensor shape that is known at network build time
37// and the LossType of the loss function.
39public:
40 StaticShape() : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
41 int batch() const {
42 return batch_;
43 }
44 void set_batch(int value) {
45 batch_ = value;
46 }
47 int height() const {
48 return height_;
49 }
50 void set_height(int value) {
51 height_ = value;
52 }
53 int width() const {
54 return width_;
55 }
56 void set_width(int value) {
57 width_ = value;
58 }
59 int depth() const {
60 return depth_;
61 }
62 void set_depth(int value) {
63 depth_ = value;
64 }
66 return loss_type_;
67 }
69 loss_type_ = value;
70 }
71 void SetShape(int batch, int height, int width, int depth) {
72 batch_ = batch;
73 height_ = height;
74 width_ = width;
75 depth_ = depth;
76 }
77
78 void Print() const {
79 tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_, height_, width_, depth_,
80 loss_type_);
81 }
82
83 bool DeSerialize(TFile *fp) {
84 int32_t tmp = LT_NONE;
85 bool result = fp->DeSerialize(&batch_) && fp->DeSerialize(&height_) &&
86 fp->DeSerialize(&width_) && fp->DeSerialize(&depth_) && fp->DeSerialize(&tmp);
87 loss_type_ = static_cast<LossType>(tmp);
88 return result;
89 }
90
91 bool Serialize(TFile *fp) const {
92 int32_t tmp = loss_type_;
93 return fp->Serialize(&batch_) && fp->Serialize(&height_) && fp->Serialize(&width_) &&
94 fp->Serialize(&depth_) && fp->Serialize(&tmp);
95 }
96
97private:
98 // Size of the 4-D tensor input/output to a network. A value of zero is
99 // allowed for all except depth_ and means to be determined at runtime, and
100 // regarded as variable.
101 // Number of elements in a batch, or number of frames in a video stream.
102 int32_t batch_;
103 // Height of the image.
104 int32_t height_;
105 // Width of the image.
106 int32_t width_;
107 // Depth of the image. (Number of "nodes").
108 int32_t depth_;
109 // How to train/interpret the output.
110 LossType loss_type_;
111};
112
113} // namespace tesseract
114
115#endif // TESSERACT_LSTM_STATIC_SHAPE_H_
int value
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
bool DeSerialize(std::string &data)
Definition: serialis.cpp:94
bool Serialize(const std::string &data)
Definition: serialis.cpp:107
void set_batch(int value)
Definition: static_shape.h:44
void set_loss_type(LossType value)
Definition: static_shape.h:68
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:71
void set_depth(int value)
Definition: static_shape.h:62
LossType loss_type() const
Definition: static_shape.h:65
bool Serialize(TFile *fp) const
Definition: static_shape.h:91
void set_width(int value)
Definition: static_shape.h:56
void set_height(int value)
Definition: static_shape.h:50
bool DeSerialize(TFile *fp)
Definition: static_shape.h:83