46 int net_flags,
float weight_range,
TRand *randomizer,
49 Series *bottom_series =
nullptr;
51 if (append_index >= 0) {
54 auto *series =
static_cast<Series *
>(*network);
55 Series *top_series =
nullptr;
56 series->
SplitAt(append_index, &bottom_series, &top_series);
57 if (bottom_series ==
nullptr || top_series ==
nullptr) {
58 tprintf(
"Yikes! Splitting current network failed!!\n");
61 input_shape = bottom_series->
OutputShape(input_shape);
65 if (*network ==
nullptr) {
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(
false);
71 if (bottom_series !=
nullptr) {
73 *network = bottom_series;
// File-local helper for the spec parser. The loop condition tests the
// character currently pointed to by *str against space, tab and newline.
// NOTE(review): the loop body and the closing braces are missing from this
// listing; presumably the body advances *str past the whitespace character --
// confirm against the full file.
80static void SkipWhitespace(
const char **str) {
81 while (**str ==
' ' || **str ==
'\t' || **str ==
'\n') {
// NOTE(review): the header of this function is not visible in this listing.
// Judging by the names used (input_shape, str) this looks like the body of
// NetworkBuilder::BuildFromString -- confirm against the full file.
// Each return below hands the spec text to one specialised parser; the
// conditions that choose between them are elided from this listing.
// ParseSeries is called here with a null input layer.
92 return ParseSeries(input_shape,
nullptr, str);
// An incoming shape whose depth is 0 gets its own check; the visible
// statement that follows parses an input-layer spec.
94 if (input_shape.
depth() == 0) {
96 return ParseInput(str);
100 return ParseParallel(input_shape, str);
102 return ParseR(input_shape, str);
104 return ParseS(input_shape, str);
106 return ParseC(input_shape, str);
108 return ParseM(input_shape, str);
110 return ParseLSTM(input_shape, str);
112 return ParseFullyConnected(input_shape, str);
114 return ParseOutput(input_shape, str);
// Diagnostic that prints the remaining spec text; presumably reached when none
// of the parsers above was selected -- confirm.
116 tprintf(
"Invalid network spec:%s\n", *str);
// Parses an input-layer spec of the form "<batch>,<height>,<width>,<depth>"
// from the text at *str and creates an Input layer named "Input" with that
// shape. One visible return passes the new layer on to ParseSeries; other
// exits are elided from this listing.
124Network *NetworkBuilder::ParseInput(
const char **str) {
127 int batch, height, width, depth;
// The trailing %n stores the number of characters consumed into length. The
// declaration of length is not visible in this listing.
128 int num_converted = sscanf(*str,
"%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
// NOTE(review): the shape is filled in before num_converted is checked below,
// so when fewer than four fields match, this reads ints that were declared
// without initialisers. Consider validating first.
130 shape.
SetShape(batch, height, width, depth);
// Accepts a conversion count of either 4 or 5; anything else is reported as a
// missing input layer.
132 if (num_converted != 4 && num_converted != 5) {
133 tprintf(
"Must specify an input layer as the first layer, not %s!!\n", *str);
137 auto *input =
new Input(
"Input", shape);
142 return ParseSeries(shape, input, str);
// Builds a Series named "Series" from consecutive layer specs in *str.
// input_shape seeds the running shape; input_layer may be null.
// NOTE(review): the end of the parameter list, the final return and the
// closing braces are elided from this listing.
148Network *NetworkBuilder::ParseSeries(
const StaticShape &input_shape, Input *input_layer,
150 StaticShape shape = input_shape;
151 auto *series =
new Series(
"Series");
// When an input layer was supplied it goes on the stack first and its output
// shape replaces the running shape.
153 if (input_layer !=
nullptr) {
154 series->AddToStack(input_layer);
155 shape = input_layer->OutputShape(shape);
157 Network *network =
nullptr;
// Keep parsing layers until the end of the string, a ']' or a null result from
// BuildFromString. Each layer's output shape becomes the input shape of the
// next one.
158 while (**str !=
'\0' && **str !=
']' && (network =
BuildFromString(shape, str)) !=
nullptr) {
159 shape = network->OutputShape(shape);
160 series->AddToStack(network);
// Diagnostic for an unterminated series; the condition guarding it is elided.
163 tprintf(
"Missing ] at end of [Series]!\n");
// Builds a Parallel named "Parallel" (type NT_PARALLEL) and adds parsed
// networks to its stack.
// NOTE(review): the rest of the loop condition, the final return and the
// closing braces are elided from this listing.
172Network *NetworkBuilder::ParseParallel(
const StaticShape &input_shape,
const char **str) {
173 auto *parallel =
new Parallel(
"Parallel",
NT_PARALLEL);
175 Network *network =
nullptr;
// The loop stops at the end of the string or at ')'; the remainder of the
// condition (presumably the call that produces network) is not visible.
176 while (**str !=
'\0' && **str !=
')' &&
178 parallel->AddToStack(network);
// Diagnostic for an unterminated parallel group; its guard is elided.
181 tprintf(
"Missing ) at end of (Parallel)!\n");
// Parses an "R" spec. The character after the leading one selects the form:
// 'x' or 'y' takes the branch that wraps a parsed network via
// rev->SetNetwork; otherwise an integer replica count is read and that many
// networks are added to parallel.
// NOTE(review): the declarations of rev, network, end and parallel, the
// returns and the closing braces are elided from this listing.
190Network *NetworkBuilder::ParseR(
const StaticShape &input_shape,
const char **str) {
191 char dir = (*str)[1];
192 if (dir ==
'x' || dir ==
'y') {
193 std::string name =
"Reverse";
197 if (network ==
nullptr) {
201 rev->SetNetwork(network);
// Replica count: base-10 integer starting one character into the spec.
205 int replicas = strtol(*str + 1, &end, 10);
208 tprintf(
"Invalid R spec!:%s\n", end);
// Saves the current parse position; presumably so every replica can be parsed
// from the same spec text -- confirm against the elided loop body.
212 const char *str_copy = *str;
213 for (
int i = 0;
i < replicas; ++
i) {
216 if (network ==
nullptr) {
217 tprintf(
"Invalid replicated network!\n");
221 parallel->AddToStack(network);
// Parses an "S" spec. With two positive integers y and x it returns a Reconfig
// named "Reconfig" built from the input depth, x and y. A '(' form is
// recognised but reported as not implemented.
// NOTE(review): the declaration of end, the outer condition, the returns on
// the error paths and the closing braces are elided from this listing.
228Network *NetworkBuilder::ParseS(
const StaticShape &input_shape,
const char **str) {
230 int y = strtol(*str + 1, &end, 10);
// NOTE(review): both visible strtol calls start at *str + 1; presumably *str
// is advanced on the elided lines between them -- confirm.
233 int x = strtol(*str + 1, &end, 10);
// Both factors must be strictly positive.
235 if (
y <= 0 ||
x <= 0) {
236 tprintf(
"Invalid S spec!:%s\n", *str);
239 return new Reconfig(
"Reconfig", input_shape.depth(),
x,
y);
240 }
else if (**str ==
'(') {
242 tprintf(
"Generic reshape not yet implemented!!\n");
245 tprintf(
"Invalid S spec!:%s\n", *str);
// Parses a "C" (convolution) spec: three comma-separated positive integers
// y, x, d read starting two characters into the spec. A 1x1 window becomes a
// single FullyConnected named "Conv1x1"; any other size becomes a Series
// "ConvSeries" of a Convolve followed by a FullyConnected named "ConvNL".
// NOTE(review): the derivation of type, the declaration of end, the error
// returns, the final return and the closing braces are elided here.
272Network *NetworkBuilder::ParseC(
const StaticShape &input_shape,
const char **str) {
275 tprintf(
"Invalid nonlinearity on C-spec!: %s\n", *str);
278 int y = 0,
x = 0, d = 0;
// Each number must be > 0 and the first two must be followed by ','. The
// short-circuit || chain stops at the first failure, leaving end at the
// position reported in the message below.
280 if ((
y = strtol(*str + 2, &end, 10)) <= 0 || *end !=
',' ||
281 (
x = strtol(end + 1, &end, 10)) <= 0 || *end !=
',' || (d = strtol(end + 1, &end, 10)) <= 0) {
282 tprintf(
"Invalid C spec!:%s\n", end);
// 1x1 case: no Convolve layer is created.
286 if (
x == 1 &&
y == 1) {
289 return new FullyConnected(
"Conv1x1", input_shape.depth(), d,
type);
291 auto *series =
new Series(
"ConvSeries");
// Convolve receives x / 2 and y / 2 using integer division, so an even size
// and the next odd size up pass the same value.
292 auto *convolve =
new Convolve(
"Convolve", input_shape.depth(),
x / 2,
y / 2);
293 series->AddToStack(convolve);
// The fully connected layer's input depth comes from the Convolve output shape.
294 StaticShape fc_input = convolve->OutputShape(input_shape);
295 series->AddToStack(
new FullyConnected(
"ConvNL", fc_input.depth(), d,
type));
// Parses an "Mp" (maxpool) spec and returns a Maxpool named "Maxpool" built
// from the input depth, x and y.
// NOTE(review): the declarations of y, x and end, the error return and the
// closing braces are elided from this listing.
300Network *NetworkBuilder::ParseM(
const StaticShape &input_shape,
const char **str) {
// Valid only if the second character is 'p', followed by a positive y, a ','
// and a positive x; any failure is reported as an invalid spec.
303 if ((*str)[1] !=
'p' || (
y = strtol(*str + 2, &end, 10)) <= 0 || *end !=
',' ||
304 (
x = strtol(end + 1, &end, 10)) <= 0) {
305 tprintf(
"Invalid Mp spec!:%s\n", *str);
309 return new Maxpool(
"Maxpool", input_shape.depth(),
x,
y);
// Parses an "L" (LSTM) spec. The character after the leading one (key)
// selects the variant; a positive state count is then read, and the resulting
// LSTM may be wrapped via rev->SetNetwork or placed in parallel with a second
// LSTM.
// NOTE(review): many lines are elided from this listing, including the
// declarations of num_outputs, type, end, rev and parallel, the first branch
// condition, the conditions guarding the wrapping code, all returns and the
// closing braces. Comments below describe only the visible statements.
313Network *NetworkBuilder::ParseLSTM(
const StaticShape &input_shape,
const char **str) {
// Start of the spec text; used later to build the layer name.
316 const char *spec_start = *str;
317 int chars_consumed = 1;
// key is the character at index chars_consumed (1 here). dir and dim start
// out as 'f' and 'x'.
319 char key = (*str)[chars_consumed], dir =
'f', dim =
'x';
322 num_outputs = num_softmax_outputs_;
324 }
else if (key ==
'E') {
326 num_outputs = num_softmax_outputs_;
328 }
else if (key ==
'2' &&
329 (((*str)[2] ==
'x' && (*str)[3] ==
'y') || ((*str)[2] ==
'y' && (*str)[3] ==
'x'))) {
333 }
else if (key ==
'f' || key ==
'r' || key ==
'b') {
336 if (dim !=
'x' && dim !=
'y') {
337 tprintf(
"Invalid dimension (x|y) in L Spec!:%s\n", *str);
341 if ((*str)[chars_consumed] ==
's') {
346 tprintf(
"Invalid direction (f|r|b) in L Spec!:%s\n", *str);
// State count: base-10 integer after the characters consumed so far; it must
// be strictly positive.
350 int num_states = strtol(*str + chars_consumed, &end, 10);
351 if (num_states <= 0) {
352 tprintf(
"Invalid number of states in L Spec!:%s\n", *str);
356 Network *lstm =
nullptr;
358 lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
// An output count of 0 is replaced by the state count.
360 if (num_outputs == 0) {
361 num_outputs = num_states;
// The layer name is the spec text from spec_start up to the current *str.
363 std::string name(spec_start, *str - spec_start);
364 lstm =
new LSTM(name, input_shape.depth(), num_states, num_outputs,
false,
type);
367 rev->SetNetwork(lstm);
// A second LSTM with identical constructor arguments is added to parallel
// ahead of lstm.
373 parallel->AddToStack(
374 new LSTM(name, input_shape.depth(), num_states, num_outputs,
false,
type));
375 parallel->AddToStack(lstm);
381 rev->SetNetwork(lstm);
// Assembles four LSTM layers, each constructed with
// (name, num_inputs, num_states, num_states, true, NT_LSTM), and adds them to
// parallel. The first is added directly; the other three are installed with
// SetNetwork on rev (the third additionally through rev2).
// NOTE(review): the creation of parallel, rev and rev2, the return and the
// closing brace are elided from this listing.
388Network *NetworkBuilder::BuildLSTMXYQuad(
int num_inputs,
int num_states) {
390 parallel->AddToStack(
new LSTM(
"L2DLTRDown", num_inputs, num_states, num_states,
true,
NT_LSTM));
392 rev->SetNetwork(
new LSTM(
"L2DRTLDown", num_inputs, num_states, num_states,
true,
NT_LSTM));
393 parallel->AddToStack(rev);
395 rev->SetNetwork(
new LSTM(
"L2DRTLUp", num_inputs, num_states, num_states,
true,
NT_LSTM));
// This branch is double-wrapped: rev is installed in rev2, and rev2 is what
// goes on the stack.
397 rev2->SetNetwork(rev);
398 parallel->AddToStack(rev2);
// NOTE(review): this layer reuses the name "L2DLTRDown" already given to the
// first LSTM above -- confirm whether a distinct name was intended.
400 rev->SetNetwork(
new LSTM(
"L2DLTRDown", num_inputs, num_states, num_states,
true,
NT_LSTM));
401 parallel->AddToStack(rev);
// File-local helper that creates a FullyConnected layer called name with
// depth outputs. Its input depth is height * width * depth of input_shape.
// When height * width exceeds 1, a Series named "FCSeries" is created and fc
// is added to it, with a Reconfig "FCReconfig" constructed ahead of it.
// NOTE(review): the error return, the call receiving the Reconfig, the final
// returns and the closing braces are elided from this listing.
406static Network *BuildFullyConnected(
const StaticShape &input_shape,
NetworkType type,
407 const std::string &name,
int depth) {
// NOTE(review): the message says "positive" but the check only rejects values
// equal to 0.
408 if (input_shape.height() == 0 || input_shape.width() == 0) {
409 tprintf(
"Fully connected requires positive height and width, had %d,%d\n", input_shape.height(),
410 input_shape.width());
413 int input_size = input_shape.height() * input_shape.width();
414 int input_depth = input_size * input_shape.depth();
415 Network *fc =
new FullyConnected(name, input_depth, depth,
type);
416 if (input_size > 1) {
417 auto *series =
new Series(
"FCSeries");
// The Reconfig is built from the input depth, width and height, in that order.
419 new Reconfig(
"FCReconfig", input_shape.depth(), input_shape.width(), input_shape.height()));
420 series->AddToStack(fc);
// Parses an "F" (fully connected) spec: the output depth is a base-10 integer
// starting two characters into the spec. The layer itself is built by
// BuildFullyConnected, named after the spec text consumed.
// NOTE(review): the derivation of type, the declaration of end, the validity
// checks, the error returns and the closing braces are elided here.
427Network *NetworkBuilder::ParseFullyConnected(
const StaticShape &input_shape,
const char **str) {
428 const char *spec_start = *str;
431 tprintf(
"Invalid nonlinearity on F-spec!: %s\n", *str);
435 int depth = strtol(*str + 2, &end, 10);
437 tprintf(
"Invalid F spec!:%s\n", *str);
// The layer name is the spec text from spec_start up to the current *str.
441 std::string name(spec_start, *str - spec_start);
442 return BuildFullyConnected(input_shape,
type, name, depth);
// Parses an "O" (output) spec. Character 1 is the dimensionality ('0', '1' or
// '2'), character 2 is the output type ('l', 's' or 'c'), and the output depth
// is a base-10 integer starting at character 3.
// NOTE(review): the declarations of end and type, the bodies of the type
// branches, the error returns, the final returns and the closing braces are
// elided from this listing.
446Network *NetworkBuilder::ParseOutput(
const StaticShape &input_shape,
const char **str) {
447 char dims_ch = (*str)[1];
448 if (dims_ch !=
'0' && dims_ch !=
'1' && dims_ch !=
'2') {
449 tprintf(
"Invalid dims (2|1|0) in output spec!:%s\n", *str);
452 char type_ch = (*str)[2];
453 if (type_ch !=
'l' && type_ch !=
's' && type_ch !=
'c') {
454 tprintf(
"Invalid output type (l|s|c) in output spec!:%s\n", *str);
458 int depth = strtol(*str + 3, &end, 10);
// A depth that differs from num_softmax_outputs_ is not an error: a warning is
// printed and the member value is used instead.
459 if (depth != num_softmax_outputs_) {
460 tprintf(
"Warning: given outputs %d not equal to unicharset of %d.\n", depth,
461 num_softmax_outputs_);
462 depth = num_softmax_outputs_;
466 if (type_ch ==
'l') {
468 }
else if (type_ch ==
's') {
// '0': delegate to BuildFullyConnected with the name "Output".
471 if (dims_ch ==
'0') {
473 return BuildFullyConnected(input_shape,
type,
"Output", depth);
474 }
// '2': a plain FullyConnected "Output2d" whose input depth is the incoming depth.
else if (dims_ch ==
'2') {
476 return new FullyConnected(
"Output2d", input_shape.depth(), depth,
type);
// Remaining case: only the height is folded into the input depth, so a height
// of 0 is rejected.
479 if (input_shape.height() == 0) {
480 tprintf(
"Fully connected requires fixed height!\n");
483 int input_size = input_shape.height();
484 int input_depth = input_size * input_shape.depth();
485 Network *fc =
new FullyConnected(
"Output", input_depth, depth,
type);
// When height > 1, a Reconfig built from (depth, 1, height) goes on the stack
// ahead of fc.
486 if (input_size > 1) {
487 auto *series =
new Series(
"FCSeries");
488 series->AddToStack(
new Reconfig(
"FCReconfig", input_shape.depth(), 1, input_shape.height()));
489 series->AddToStack(fc);
void tprintf(const char *format,...)
@ NT_LSTM_SOFTMAX_ENCODED
virtual void SetNetworkFlags(uint32_t flags)
TESS_API void AppendSeries(Network *src)
TESS_API void SplitAt(unsigned last_start, Series **start, Series **end)
StaticShape OutputShape(const StaticShape &input_shape) const override
void CacheXScaleFactor(int factor) override
void SetShape(int batch, int height, int width, int depth)
static bool InitNetwork(int num_outputs, const char *network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Network * BuildFromString(const StaticShape &input_shape, const char **str)