41 :
Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
93 no_ = code_map.size();
125 int width = input.
Width();
136 temp_lines[i].Init(
no_, scratch);
137 curr_input[i].Init(
ni_, scratch);
140 #pragma omp parallel for num_threads(kNumThreads) 141 for (
int t = 0; t < width; ++t) {
143 int thread_id = omp_get_thread_num();
145 for (
int t = 0; t < width; ++t) {
149 double* temp_line = temp_lines[thread_id];
150 const double* d_input = NULL;
151 const inT8* i_input = NULL;
153 i_input = input.
i(t);
156 d_input = curr_input[thread_id];
192 int t,
double* output_line) {
201 FuncInplace<GFunc>(
no_, output_line);
203 FuncInplace<FFunc>(
no_, output_line);
205 FuncInplace<ClipFFunc>(
no_, output_line);
207 FuncInplace<ClipGFunc>(
no_, output_line);
209 FuncInplace<Relu>(
no_, output_line);
213 ASSERT_HOST(
"Invalid fully-connected type!" == NULL);
226 for (
int i = 0; i <
kNumThreads; ++i) errors[i].Init(
no_, scratch);
230 for (
int i = 0; i <
kNumThreads; ++i) temp_backprops[i].Init(
ni_, scratch);
232 int width = fwd_deltas.
Width();
234 errors_t.
Init(
no_, width, scratch);
236 #pragma omp parallel for num_threads(kNumThreads) 237 for (
int t = 0; t < width; ++t) {
238 int thread_id = omp_get_thread_num();
240 for (
int t = 0; t < width; ++t) {
243 double* backprop = NULL;
245 double* curr_errors = errors[thread_id];
247 if (backprop != NULL) {
256 back_deltas->
Print(10);
279 fwd_deltas.ReadTimeStep(t, curr_errors);
281 ASSERT_HOST(
"Invalid fully-connected type!" == NULL);
297 float adam_beta,
int num_samples) {
305 double* changed)
const {
virtual void SetEnableTraining(TrainingState state)
virtual int InitWeights(float range, TRand *randomizer)
void SoftmaxInPlace(int n, T *inout)
virtual bool Serialize(TFile *fp) const
virtual ~FullyConnected()
void Resize(const NetworkIO &src, int num_features)
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
int RemapOutputs(const std::vector< int > &code_map)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void ResizeNoInit(int size1, int size2, int pad=0)
virtual void ConvertToInt()
virtual void SetRandomizer(TRand *randomizer)
void ZeroInvalidElements()
void DisplayForward(const NetworkIO &matrix)
void ForwardTimeStep(const double *d_input, const inT8 *i_input, int t, double *output_line)
bool TestFlag(NetworkFlags flag) const
void MatrixDotVector(const double *u, double *v) const
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
virtual void CountAlternators(const Network &other, double *same, double *changed) const
void FinishBackward(const TransposedArray &errors_t)
void VectorDotMatrix(const double *u, double *v) const
void Init(int size1, int size2, NetworkScratch *scratch)
virtual bool Serialize(TFile *fp) const
void ReadTimeStep(int t, double *output) const
TransposedArray * get() const
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
void set_loss_type(LossType value)
const char * string() const
void WriteTimeStep(int t, const double *input)
void Debug2D(const char *msg)
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void ResizeFloat(const NetworkIO &src, int num_features)
virtual bool DeSerialize(TFile *fp)
void DisplayBackward(const NetworkIO &matrix)
void Print(int num) const
void set_depth(int value)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
void WriteStrided(int t, const float *data)
bool DeSerialize(bool training, TFile *fp)
virtual void DebugWeights()
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
void init_to_size(int size, T t)
const TransposedArray * external_source_
TransposedArray source_t_
const inT8 * i(int t) const
bool Serialize(bool training, TFile *fp) const