18#ifdef INCLUDE_TENSORFLOW
22# include <allheaders.h>
26using tensorflow::Status;
27using tensorflow::Tensor;
28using tensorflow::TensorShape;
32TFNetwork::TFNetwork(
const char *name) : Network(
NT_TENSORFLOW, name, 0, 0) {}
// Parses `proto_str` into model_proto_ and then initializes the network from
// it by returning InitFromProto().
// NOTE(review): this view is missing the source line that follows the `if`
// (presumably the failed-parse return) as well as the closing brace. As shown,
// the `if` would guard `return InitFromProto();` itself. Confirm against the
// full file.
34int TFNetwork::InitFromProtoStr(
const std::string &proto_str) {
35 if (!model_proto_.ParseFromString(proto_str))
37 return InitFromProto();
// NOTE(review): body fragment whose function header is not visible here
// (presumably TFNetwork::Serialize).
// Serializes model_proto_ to a string, copies the bytes into a
// std::vector<char>, and returns the result of writing that vector with
// fp->Serialize.
// NOTE(review): &data[0] is undefined behavior when proto_str is empty, because
// the vector is then empty. Constructing the vector as
// `data(proto_str.begin(), proto_str.end())` avoids both the memcpy and the UB.
45 std::string proto_str;
46 model_proto_.SerializeToString(&proto_str);
48 std::vector<char> data(proto_str.size());
49 memcpy(&data[0], proto_str.data(), proto_str.size());
50 return fp->Serialize(data);
// NOTE(review): body fragment whose function header is not visible here
// (presumably TFNetwork::DeSerialize).
// Reads a byte vector from fp, parses it into model_proto_, and returns the
// result of InitFromProto().
// NOTE(review): the statements executed when each check fails (presumably
// early returns) and the closing braces are not visible in this view.
// NOTE(review): &data[0] is undefined behavior if `data` is empty; data.data()
// avoids that.
56 std::vector<char> data;
57 if (!fp->DeSerialize(data))
59 if (!model_proto_.ParseFromArray(&data[0], data.size())) {
62 return InitFromProto();
// Feeds `input` (plus optional width/height tensors) to session_->Run, fetches
// the layer named by model_proto_.output_layer(), and copies the first fetched
// tensor into `output`.
// `debug`, `input_transpose` and `scratch` are not referenced in the lines
// shown here.
// NOTE(review): several source lines (status guard, assertions, closing braces)
// are missing from this view.
67void TFNetwork::Forward(
bool debug,
const NetworkIO &input,
const TransposedArray *input_transpose,
68 NetworkScratch *scratch, NetworkIO *
output) {
// (tensor name, tensor) pairs passed to session_->Run as the feed list.
69 std::vector<std::pair<std::string, Tensor>> tf_inputs;
70 int depth = input_shape_.depth();
73 const StrideMap &stride_map = input.stride_map();
// Input tensor shape is {1, height, width, depth}; the leading dimension is
// fixed at 1.
75 TensorShape shape{1, stride_map.Size(
FD_HEIGHT), stride_map.Size(
FD_WIDTH), depth};
76 Tensor input_tensor(tensorflow::DT_FLOAT, shape);
78 auto eigen_tensor = input_tensor.flat<
float>();
// Copies input.Width() * depth elements starting at input.f(0) into the
// tensor's flat buffer.
// NOTE(review): this assumes input's features are contiguous from f(0) and
// that input.Width() equals height * width of the stride map. Confirm.
79 memcpy(eigen_tensor.data(), input.f(0), input.Width() * depth *
sizeof(input.f(0)[0]));
81 tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);
// One-element int64 tensors carrying the image width and height. Each is fed
// only when the proto names the corresponding input.
88 if (!model_proto_.image_widths().empty()) {
89 TensorShape size_shape{1};
90 Tensor width_tensor(tensorflow::DT_INT64, size_shape);
91 auto eigen_wtensor = width_tensor.flat<tensorflow::int64>();
92 *eigen_wtensor.data() = stride_map.Size(
FD_WIDTH);
93 tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
95 if (!model_proto_.image_heights().empty()) {
96 TensorShape size_shape{1};
97 Tensor height_tensor(tensorflow::DT_INT64, size_shape);
98 auto eigen_htensor = height_tensor.flat<tensorflow::int64>();
99 *eigen_htensor.data() = stride_map.Size(
FD_HEIGHT);
100 tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
// Fetch only the output layer named in the proto; no target nodes ({}).
102 std::vector<std::string> target_layers = {model_proto_.output_layer()};
103 std::vector<Tensor> outputs;
104 Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
// NOTE(review): in this view the tprintf is not preceded by a status check
// (e.g. `if (!s.ok())`). The guard is presumably on a line not shown.
// Newer TensorFlow releases may deprecate or remove Status::error_message() in
// favor of message().
106 tprintf(
"session->Run failed:%s\n", s.error_message().c_str());
// NOTE(review): no visible check that `outputs` is non-empty or that the tensor
// has rank 3 before dim_size(0..2) is used.
109 const Tensor &output_tensor = outputs[0];
// The output tensor's dimensions are read as {batch, steps, depth}.
// output_batch is not referenced in the lines shown.
112 int output_batch = output_tensor.shape().dim_size(0);
113 int output_steps = output_tensor.shape().dim_size(1);
114 int output_depth = output_tensor.shape().dim_size(2);
116 ASSERT_HOST(output_depth == output_shape_.depth());
// Resize `output` from output_steps / output_depth, then copy
// output_steps * output_depth elements of the flattened tensor into it.
117 output->Resize2d(
false, output_steps, output_depth);
118 auto eigen_output = output_tensor.flat<
float>();
119 memcpy(
output->f(0), eigen_output.data(), output_steps * output_depth *
sizeof(
output->f(0)[0]));
// Sets spec_, input_shape_, output_shape_, ni_ and no_ from model_proto_. It
// then creates a new TensorFlow session with default SessionOptions and loads
// model_proto_.graph() into it. On the visible path it returns
// model_proto_.global_step().
// NOTE(review): the status check around that return, the return after the
// failure tprintf, and the closing brace are not visible in this view.
122int TFNetwork::InitFromProto() {
123 spec_ = model_proto_.spec();
// Negative y/x sizes from the proto are clamped to 0.
124 input_shape_.SetShape(model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
125 std::max(0, model_proto_.x_size()), model_proto_.depth());
126 output_shape_.SetShape(model_proto_.batch_size(), 1, 0, model_proto_.num_classes());
// Loss type: LT_CTC when the proto sets using_ctc, otherwise LT_SOFTMAX.
127 output_shape_.set_loss_type(model_proto_.using_ctc() ?
LT_CTC :
LT_SOFTMAX);
128 ni_ = input_shape_.height();
129 no_ = output_shape_.depth();
// Replace any previous session with a fresh one, then load the graph.
132 tensorflow::SessionOptions options;
133 session_.reset(NewSession(options));
134 Status s = session_->Create(model_proto_.graph());
// NOTE(review): as shown, this return is unconditional and the tprintf below is
// unreachable. Presumably an `if (s.ok())` guard sits on a line not shown.
136 return model_proto_.global_step();
137 tprintf(
"Session_->Create returned '%s'\n", s.error_message().c_str());
void tprintf(const char *format,...)
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
bool Serialize(FILE *fp, const std::vector< T > &data)