예제 #1
0
        /// <summary>
        /// Softmax loss for one target. Recomputes the normalized output distribution,
        /// optionally applies a gradient step against every output row, and returns the
        /// negative log of the probability assigned to the target.
        /// </summary>
        /// <param name="targets">Candidate target ids.</param>
        /// <param name="targetIndex">Position in <paramref name="targets"/> of the active target.</param>
        /// <param name="state">Per-thread state; its output vector is overwritten and its grad accumulated.</param>
        /// <param name="lr">Learning rate applied when <paramref name="backprop"/> is true.</param>
        /// <param name="backprop">When true, updates state.grad and the output matrix wo_.</param>
        /// <returns>-Log(output[target]).</returns>
        public override float Forward(
            int[] targets,
            int targetIndex,
            Model.State state,
            float lr,
            bool backprop)
        {
            Debug.Assert(targetIndex >= 0);
            Debug.Assert(targetIndex < targets.Length);

            ComputeOutput(state);
            int goldLabel = targets[targetIndex];

            if (!backprop)
            {
                return -Log(state.output[goldLabel]);
            }

            // One-hot target minus predicted probability, scaled by lr, for every row.
            var rowCount = wo_.Size(0);
            for (int row = 0; row < rowCount; row++)
            {
                float step = lr * ((row == goldLabel ? 1f : 0f) - state.output[row]);
                state.grad.AddRow(wo_, row, step);
                wo_.AddVectorToRow(state.hidden.Data, row, step);
            }

            return -Log(state.output[goldLabel]);
        }
예제 #2
0
        /// <summary>
        /// Binary logistic loss on a single output row: scores row <paramref name="target"/>
        /// of wo_ against the hidden vector, optionally applies a gradient step, and returns
        /// the negative log-likelihood of the given label.
        /// </summary>
        /// <param name="target">Row of the output matrix being scored.</param>
        /// <param name="state">Per-thread state; hidden is read, grad is accumulated on backprop.</param>
        /// <param name="labelIsPositive">True when the row should score 1, false when it should score 0.</param>
        /// <param name="lr">Learning rate applied when <paramref name="backprop"/> is true.</param>
        /// <param name="backprop">When true, updates state.grad and row <paramref name="target"/> of wo_.</param>
        /// <returns>-Log(score) for a positive label, -Log(1 - score) for a negative one.</returns>
        public float BinaryLogistic(
            int target,
            Model.State state,
            bool labelIsPositive,
            float lr,
            bool backprop)
        {
            var score = Sigmoid(wo_.DotRow(state.hidden.Data, target));

            if (backprop)
            {
                // A plain conditional yields the same 1/0 as Convert.ToDouble(bool) without
                // a runtime conversion call or a double round-trip.
                var label = labelIsPositive ? 1f : 0f;
                var alpha = lr * (label - score);
                state.grad.AddRow(wo_, target, alpha);
                wo_.AddVectorToRow(state.hidden.Data, target, alpha);
            }

            return labelIsPositive ? -Log(score) : -Log(1f - score);
        }
예제 #3
0
        /// <summary>
        /// Writes softmax(wo_ * hidden) into state.output. The maximum logit is subtracted
        /// before exponentiation so Math.Exp works on non-positive arguments.
        /// </summary>
        /// <param name="state">Per-thread state whose output vector is overwritten.</param>
        public override void ComputeOutput(Model.State state)
        {
            Vector scores = state.output;
            scores.Mul(wo_, state.hidden);

            var count = scores.Size();

            // Pass 1: find the largest logit.
            var peak = scores[0];
            for (int j = 0; j < count; j++)
            {
                peak = Math.Max(scores[j], peak);
            }

            // Pass 2: exponentiate the shifted logits and accumulate their sum.
            var total = 0f;
            for (int j = 0; j < count; j++)
            {
                scores[j] = (float)Math.Exp(scores[j] - peak);
                total += scores[j];
            }

            // Pass 3: normalize so the entries sum to one.
            for (int j = 0; j < count; j++)
            {
                scores[j] = scores[j] / total;
            }
        }
예제 #4
0
 /// <summary>
 /// Collects the best labels by walking the output tree with DFS, starting from
 /// node 2 * osz_ - 2 (presumably the tree root) with a zero accumulated score.
 /// </summary>
 /// <param name="k">Number of predictions requested.</param>
 /// <param name="threshold">Score threshold forwarded to DFS.</param>
 /// <param name="heap">Receives the predictions.</param>
 /// <param name="state">Per-thread state; only its hidden vector is read.</param>
 public override void Predict(
     int k,
     float threshold,
     Predictions heap,
     Model.State state)
 {
     var startNode = 2 * osz_ - 2;
     var hiddenVector = state.hidden.Data;
     DFS(k, threshold, startNode, 0f, heap, hiddenVector);
 }
예제 #5
0
 /// <summary>
 /// Default prediction: recomputes state.output, then hands the raw output buffer
 /// to FindKBest to select up to <paramref name="k"/> entries.
 /// </summary>
 /// <param name="k">Number of predictions requested.</param>
 /// <param name="threshold">Score threshold forwarded to FindKBest.</param>
 /// <param name="heap">Receives the predictions.</param>
 /// <param name="state">Per-thread state; its output vector is overwritten.</param>
 public virtual void Predict(
     int k,
     float threshold,
     Predictions heap,
     Model.State state)
 {
     ComputeOutput(state);
     var outputBuffer = state.output.Data;
     FindKBest(k, threshold, heap, outputBuffer);
 }
예제 #6
0
        /// <summary>
        /// Training loop for one worker thread. The thread opens its own read stream on
        /// args_.input, seeks to offset threadId * Length / args_.thread, and keeps consuming
        /// lines (supervised, cbow or skipgram depending on args_.model) until the shared
        /// token counter reaches args_.epoch * dict_.ntokens. The learning rate decays
        /// linearly with that progress.
        /// </summary>
        /// <param name="threadId">Zero-based worker index; thread 0 also publishes loss_.</param>
        protected void TrainThread(int threadId)
        {
            // using: the stream is released even if GetLine or an update throws.
            using (var ifs = new FileStream(args_.input, FileMode.Open, FileAccess.Read))
            {
                ifs.Seek(threadId * ifs.Length / args_.thread, SeekOrigin.Begin);

                var state = new Model.State(args_.dim, (int)output_.Size(0), threadId);

                var ntokens = dict_.ntokens;
                // Widen before multiplying so epoch * ntokens cannot overflow int on large corpora.
                var totalTokens     = (long)args_.epoch * ntokens;
                var localTokenCount = 0L;
                var line            = new List<int>();
                var labels          = new List<int>();

                while (true)
                {
                    // tokenCount_ is advanced by every worker thread; read it atomically.
                    // NOTE(review): assumes tokenCount_ is a long field (required for ref).
                    var processed = System.Threading.Interlocked.Read(ref tokenCount_);
                    if (processed >= totalTokens)
                    {
                        break;
                    }

                    var progress = (float)processed / totalTokens;
                    var lr       = (float)args_.lr * (1f - progress);

                    if (args_.model == ModelName.sup)
                    {
                        localTokenCount += dict_.GetLine(ifs, line, labels);
                        Supervised(state, lr, line.ToArray(), labels.ToArray());
                    }
                    else if (args_.model == ModelName.cbow)
                    {
                        localTokenCount += dict_.GetLine(ifs, line, state.rng);
                        Cbow(state, lr, line.ToArray());
                    }
                    else if (args_.model == ModelName.sg)
                    {
                        localTokenCount += dict_.GetLine(ifs, line, state.rng);
                        Skipgram(state, lr, line.ToArray());
                    }

                    if (localTokenCount > args_.lrUpdateRate)
                    {
                        // A plain "+=" from several threads can lose increments; publish atomically.
                        System.Threading.Interlocked.Add(ref tokenCount_, localTokenCount);
                        localTokenCount = 0;

                        if (threadId == 0 && args_.verbose > 1)
                        {
                            loss_ = state.GetLoss();
                        }
                    }
                }

                if (threadId == 0)
                {
                    loss_ = state.GetLoss();
                }
            }
        }
예제 #7
0
        /// <summary>
        /// Writes wo_ * hidden into state.output and then applies Sigmoid to each entry
        /// independently (no normalization across entries).
        /// </summary>
        /// <param name="state">Per-thread state whose output vector is overwritten.</param>
        public override void ComputeOutput(Model.State state)
        {
            Vector logits = state.output;
            logits.Mul(wo_, state.hidden);

            var n   = logits.Size();
            var idx = 0;
            while (idx < n)
            {
                logits[idx] = Sigmoid(logits[idx]);
                idx++;
            }
        }
예제 #8
0
 /// <summary>
 /// Skip-gram update for one line of word ids: for each centre position w, draws a
 /// random window radius and calls model_.Update once per in-range context position,
 /// using the centre word's subwords as input and the context position as target.
 /// </summary>
 /// <param name="state">Per-thread model state; supplies the random generator.</param>
 /// <param name="lr">Current learning rate.</param>
 /// <param name="line">Word ids of the current line.</param>
 protected void Skipgram(Model.State state, float lr, int[] line)
 {
     for (int w = 0; w < line.Length; w++)
     {
         // Random.Next(min, max) excludes max, so "+ 1" is needed for the radius to be
         // able to reach args_.ws (an inclusive [1, ws] window, as in fastText's skipgram).
         // NOTE(review): assumes state.rng is a System.Random — confirm against Model.State.
         var boundary = state.rng.Next(1, args_.ws + 1);
         var ngrams   = dict_.GetSubwords(line[w]);
         for (int c = -boundary; c <= boundary; c++)
         {
             // Skip the centre word itself and positions outside the line.
             if (c != 0 && w + c >= 0 && w + c < line.Length)
             {
                 model_.Update(ngrams, line, w + c, lr, state);
             }
         }
     }
 }
예제 #9
0
        /// <summary>
        /// Predicts up to <paramref name="k"/> labels for a bag of word ids and adds them to
        /// <paramref name="predictions"/>. An empty input produces no predictions.
        /// </summary>
        /// <param name="k">Number of predictions requested (forwarded to the model).</param>
        /// <param name="words">Input word/subword ids.</param>
        /// <param name="predictions">Receives the predictions.</param>
        /// <param name="threshold">Score threshold forwarded to the model.</param>
        /// <exception cref="ArgumentNullException"><paramref name="words"/> is null.</exception>
        /// <exception cref="ArgumentException">The loaded model is not supervised.</exception>
        public void Predict(int k, int[] words, Predictions predictions, float threshold = 0f)
        {
            if (words == null)
            {
                throw new ArgumentNullException(nameof(words));
            }

            if (words.Length == 0)
            {
                return;
            }

            // Validate the model kind before allocating the per-call state.
            if (args_.model != ModelName.sup)
            {
                throw new ArgumentException("Model needs to be supervised for prediction!");
            }

            var state = new Model.State(args_.dim, dict_.nlabels, 0);
            model_.Predict(words, k, threshold, predictions, state);
        }
예제 #10
0
        /// <summary>
        /// One-vs-all loss: treats every output unit as an independent binary problem.
        /// A unit is a positive example exactly when its index appears in
        /// <paramref name="targets"/>; <paramref name="targetIndex"/> is not consulted.
        /// </summary>
        /// <param name="targets">All target label ids for the example.</param>
        /// <param name="targetIndex">Unused by this loss.</param>
        /// <param name="state">Per-thread state; output size determines the number of units.</param>
        /// <param name="lr">Learning rate forwarded to BinaryLogistic.</param>
        /// <param name="backprop">Forwarded to BinaryLogistic.</param>
        /// <returns>Sum of the per-unit binary logistic losses.</returns>
        public override float Forward(
            int[] targets,
            int targetIndex,
            Model.State state,
            float lr,
            bool backprop)
        {
            var total      = 0f;
            var labelCount = state.output.Size();

            var label = 0;
            while (label < labelCount)
            {
                total += BinaryLogistic(label, state, Utils.Contains(targets, label), lr, backprop);
                label++;
            }

            return total;
        }
예제 #11
0
        /// <summary>
        /// Hierarchical-softmax loss for one target: sums a binary logistic loss for every
        /// node on the target's stored path (paths_), using the matching bit of its code
        /// (codes_) as the label.
        /// </summary>
        /// <param name="targets">Candidate target ids.</param>
        /// <param name="targetIndex">Position in <paramref name="targets"/> of the active target.</param>
        /// <param name="state">Per-thread state forwarded to BinaryLogistic.</param>
        /// <param name="lr">Learning rate forwarded to BinaryLogistic.</param>
        /// <param name="backprop">Forwarded to BinaryLogistic.</param>
        /// <returns>Sum of the binary losses along the path.</returns>
        public override float Forward(
            int[] targets,
            int targetIndex,
            Model.State state,
            float lr,
            bool backprop)
        {
            // Same index preconditions the other single-target Forward implementations assert.
            System.Diagnostics.Debug.Assert(targetIndex >= 0);
            System.Diagnostics.Debug.Assert(targetIndex < targets.Length);

            var loss   = 0f;
            var target = targets[targetIndex];

            var binaryCode = codes_[target];
            var pathToRoot = paths_[target];

            for (int i = 0; i < pathToRoot.Length; i++)
            {
                loss += BinaryLogistic(pathToRoot[i], state, binaryCode[i], lr, backprop);
            }
            return loss;
        }
예제 #12
0
        /// <summary>
        /// Negative-sampling loss: one positive binary logistic term for the target plus
        /// neg_ negative terms, each on a row drawn by getNegative.
        /// </summary>
        /// <param name="targets">Candidate target ids.</param>
        /// <param name="targetIndex">Position in <paramref name="targets"/> of the active target.</param>
        /// <param name="state">Per-thread state; its rng drives negative sampling.</param>
        /// <param name="lr">Learning rate forwarded to BinaryLogistic.</param>
        /// <param name="backprop">Forwarded to BinaryLogistic.</param>
        /// <returns>Sum of the positive and negative binary losses.</returns>
        public override float Forward(
            int[] targets,
            int targetIndex,
            Model.State state,
            float lr,
            bool backprop)
        {
            Debug.Assert(targetIndex >= 0);
            Debug.Assert(targetIndex < targets.Length);

            var positive = targets[targetIndex];
            var total    = BinaryLogistic(positive, state, true, lr, backprop);

            for (var remaining = neg_; remaining > 0; remaining--)
            {
                total += BinaryLogistic(getNegative(positive, state.rng), state, false, lr, backprop);
            }

            return total;
        }
예제 #13
0
        /// <summary>
        /// CBOW update for one line of word ids: for each position w, draws a random window
        /// radius, collects the subwords of every in-range context word into one bag, and
        /// calls model_.Update with that bag as input and position w as target.
        /// </summary>
        /// <param name="state">Per-thread model state; supplies the random generator.</param>
        /// <param name="lr">Current learning rate.</param>
        /// <param name="line">Word ids of the current line.</param>
        protected void Cbow(Model.State state, float lr, int[] line)
        {
            var bow = new List<int>();

            for (int w = 0; w < line.Length; w++)
            {
                // Random.Next(min, max) excludes max, so "+ 1" is needed for the radius to
                // be able to reach args_.ws (an inclusive [1, ws] window, as in fastText's cbow).
                // NOTE(review): assumes state.rng is a System.Random — confirm against Model.State.
                var boundary = state.rng.Next(1, args_.ws + 1);
                bow.Clear();

                for (int c = -boundary; c <= boundary; c++)
                {
                    // Skip the target word itself and positions outside the line.
                    if (c != 0 && w + c >= 0 && w + c < line.Length)
                    {
                        var ngrams = dict_.GetSubwords(line[w + c]);
                        bow.AddRange(ngrams);
                    }
                }
                model_.Update(bow.ToArray(), line, w, lr, state);
            }
        }
예제 #14
0
        /// <summary>
        /// Supervised update for one example. Does nothing when the example has no words or
        /// no labels. With the one-vs-all loss every label is a target at once; otherwise one
        /// label is sampled uniformly and used as the target.
        /// </summary>
        /// <param name="state">Per-thread model state; supplies the random generator.</param>
        /// <param name="lr">Current learning rate.</param>
        /// <param name="line">Input word/subword ids.</param>
        /// <param name="labels">Label ids of the example.</param>
        protected void Supervised(
            Model.State state,
            float lr,
            int[] line,
            int[] labels)
        {
            if (labels.Length == 0 || line.Length == 0)
            {
                return;
            }

            if (args_.loss == LossName.ova)
            {
                model_.Update(line, labels, Model.kAllLabelsAsTarget, lr, state);
            }
            else
            {
                // Random.Next(min, max) excludes max: passing labels.Length (not Length - 1)
                // keeps the last label selectable. With Length - 1, a two-label example
                // would always train on labels[0].
                // NOTE(review): assumes state.rng is a System.Random — confirm against Model.State.
                var i = state.rng.Next(0, labels.Length);
                model_.Update(line, labels, i, lr, state);
            }
        }
예제 #15
0
 /// <summary>
 /// Computes the model's output scores for the current hidden vector and stores them
 /// in <c>state.output</c> (e.g. a softmax or an element-wise sigmoid of the
 /// output-matrix product, depending on the concrete loss).
 /// </summary>
 /// <param name="state">Per-thread state; implementations overwrite its output vector.</param>
 public abstract void ComputeOutput(Model.State state);
예제 #16
0
 /// <summary>
 /// Computes the loss of the current example and, when <paramref name="backprop"/> is
 /// true, applies the corresponding gradient step.
 /// </summary>
 /// <param name="targets">Target label/word ids for the example.</param>
 /// <param name="targetIndex">Position in <paramref name="targets"/> of the active target;
 /// losses that use every target at once may ignore it.</param>
 /// <param name="state">Per-thread model state (hidden, output, grad, rng).</param>
 /// <param name="lr">Learning rate for the update.</param>
 /// <param name="backprop">Whether to update parameters and accumulate gradients.</param>
 /// <returns>The loss value for this example.</returns>
 public abstract float Forward(
     int[] targets,
     int targetIndex,
     Model.State state,
     float lr,
     bool backprop);