70 {
71
73 UNICHARSET unicharset;
74 ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(),
false));
76 std::vector<std::string> words;
78 words, words, words, false, nullptr, nullptr));
79 std::string model_path =
file::JoinPath(FLAGS_test_tmpdir, model_name);
80 std::string checkpoint_path = model_path + "_checkpoint";
81 trainer_ = std::make_unique<LSTMTrainer>(model_path.c_str(), checkpoint_path.c_str(), 0, 0);
84 int net_mode = adam ?
NF_ADAM : 0;
85
86
87 if (adam) {
88 learning_rate *= 20.0f;
89 }
90 if (layer_specific) {
92 }
94 trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, learning_rate, 0.9, 0.999));
95 std::vector<std::string> filenames;
98 LOG(
INFO) <<
"Setup network:" << model_name <<
"\n";
99 }
// Two-parameter function-like macro with an empty replacement list: as written
// here, any use expands to nothing, so neither argument is evaluated.
// NOTE(review): this looks like a signature-only stub of the test framework's
// macro rather than its real definition — confirm where the real one comes from.
#define EXPECT_EQ(val1, val2)
// One-parameter function-like macro with an empty replacement list: as written
// here, the condition is dropped and never evaluated.
// NOTE(review): presumably a stub of the test framework's macro — confirm.
#define EXPECT_TRUE(condition)
// One-parameter function-like macro with an empty replacement list: as written
// here, the condition is dropped and never evaluated, and nothing is aborted
// on failure.
// NOTE(review): presumably a stub of the test framework's macro — confirm.
#define ASSERT_TRUE(condition)
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
static std::string JoinPath(const std::string &s1, const std::string &s2)
std::unique_ptr< LSTMTrainer > trainer_
std::string TestDataNameToPath(const std::string &name)