19# include "config_auto.h"
43 :
Network(
type, name, ni, no), external_source_(nullptr), int_mode_(false) {}
96 no_ = code_map.size();
132 int width = input.
Width();
139 std::vector<NetworkScratch::FloatVec> temp_lines(
kNumThreads);
140 std::vector<NetworkScratch::FloatVec> curr_input(
kNumThreads);
146 temp_lines[
i].Init(ro, scratch);
147 curr_input[
i].Init(
ni_, scratch);
150# pragma omp parallel for num_threads(kNumThreads)
151 for (
int t = 0; t < width; ++t) {
153 int thread_id = omp_get_thread_num();
155 for (
int t = 0; t < width; ++t) {
159 TFloat *temp_line = temp_lines[thread_id];
166 output->WriteTimeStep(t, temp_line);
177 output->ZeroInvalidElements();
182#ifndef GRAPHICS_DISABLED
205 FuncInplace<GFunc>(
no_, output_line);
207 FuncInplace<FFunc>(
no_, output_line);
209 FuncInplace<ClipFFunc>(
no_, output_line);
211 FuncInplace<ClipGFunc>(
no_, output_line);
213 FuncInplace<Relu>(
no_, output_line);
217 ASSERT_HOST(
"Invalid fully-connected type!" ==
nullptr);
240#ifndef GRAPHICS_DISABLED
246 std::vector<NetworkScratch::FloatVec> errors(
kNumThreads);
248 errors[
i].Init(
no_, scratch);
250 std::vector<NetworkScratch::FloatVec> temp_backprops;
254 temp_backprops[
i].Init(
ni_, scratch);
257 int width = fwd_deltas.
Width();
259 errors_t.
Init(
no_, width, scratch);
261# pragma omp parallel for num_threads(kNumThreads)
262 for (
int t = 0; t < width; ++t) {
263 int thread_id = omp_get_thread_num();
265 for (
int t = 0; t < width; ++t) {
268 TFloat *backprop =
nullptr;
270 backprop = temp_backprops[thread_id];
272 TFloat *curr_errors = errors[thread_id];
274 if (backprop !=
nullptr) {
283 back_deltas->
Print(10);
305 ASSERT_HOST(
"Invalid fully-connected type!" ==
nullptr);
308 if (backprop !=
nullptr) {
void tprintf(const char *format,...)
void SoftmaxInPlace(int n, T *inout)
void ResizeNoInit(int size1, int size2, int pad=0)
int RoundOutputs(int size) const
static const IntSimdMatrix * intSimdMatrix
void ForwardTimeStep(int t, TFloat *output_line)
bool DeSerialize(TFile *fp) override
void DebugWeights() override
void FinishBackward(const TransposedArray &errors_t)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
void SetEnableTraining(TrainingState state) override
const TransposedArray * external_source_
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
int InitWeights(float range, TRand *randomizer) override
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, TransposedArray *errors_t, TFloat *backprop)
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
TransposedArray source_t_
void ConvertToInt() override
TESS_API FullyConnected(const std::string &name, int ni, int no, NetworkType type)
StaticShape OutputShape(const StaticShape &input_shape) const override
bool Serialize(TFile *fp) const override
void DisplayForward(const NetworkIO &matrix)
void DisplayBackward(const NetworkIO &matrix)
virtual bool Serialize(TFile *fp) const
bool TestFlag(NetworkFlags flag) const
virtual void SetRandomizer(TRand *randomizer)
void Resize(const NetworkIO &src, int num_features)
void FuncMultiply(const NetworkIO &v_io, int t, TFloat *product)
void ZeroInvalidElements()
void WriteTimeStep(int t, const TFloat *input)
void Print(int num) const
void ReadTimeStep(int t, TFloat *output) const
const int8_t * i(int t) const
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
TransposedArray * get() const
void Init(int size1, int size2, NetworkScratch *scratch)
void set_loss_type(LossType value)
void set_depth(int value)
void WriteStrided(int t, const float *data)
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
bool Serialize(bool training, TFile *fp) const
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
void Debug2D(const char *msg)
int RemapOutputs(const std::vector< int > &code_map)
void VectorDotMatrix(const TFloat *u, TFloat *v) const
void MatrixDotVector(const TFloat *u, TFloat *v) const
bool DeSerialize(bool training, TFile *fp)
void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const