Esempio n. 1
0
        /// <summary>
        /// Builds the symbolic graph for forward inference in a chain CRF from (symbolic) log-domain potentials.
        /// </summary>
        /// <param name="obs_potentials">Rank-2 tensor (n_steps, n_classes): the axes are time and the value of the discrete label variable.
        /// Each entry is the energy assigned to a configuration, so higher energy means lower probability.</param>
        /// <param name="chain_potentials">Rank-3 tensor (n_classes, n_classes, n_classes): the axes are the left label state, the right label state and the global label.
        /// Each entry is the energy of a given pair of labels being adjacent to one another (higher energy means lower probability).</param>
        /// <param name="viterbi">When true, perform MAP inference with the Viterbi algorithm: reduce with a maximum
        /// (the single most likely configuration) instead of marginalizing the step-specific label variables with LogSumExp.</param>
        /// <returns>(1-dimensional) The energy assigned to a given global label.
        /// It can be turned into a log probability by subtracting logsumexp(energy).</returns>
        public static Tensor<float> Forward(Tensor<float> obs_potentials, Tensor<float> chain_potentials, bool viterbi = false)
        {
            // Rank checks only; they are compiled out of release builds.
            Debug.Assert(obs_potentials.NDim == 2);
            Debug.Assert(chain_potentials.NDim == 3);

            // One recurrence step. chain_potentials is captured by the closure instead of
            // being handed to Scan as a non-sequence argument.
            Func<Tensor<float>, Tensor<float>, Tensor<float>> step = (obs, prior_result) =>
            {
                // Insert broadcastable axes so both operands line up with the rank-3 chain potentials.
                Tensor<float> previous = prior_result.DimShuffle(0, 'x', 1);
                Tensor<float> current = obs.DimShuffle('x', 0, 'x');
                var combined = -previous - current - chain_potentials;

                if (viterbi)
                {
                    return T.Max(combined, axis: 0);
                }

                return LogSumExp(combined, axis: 0);
            };

            // Seed of the recurrence: the first observation row broadcast against a slice of the chain potentials.
            var firstObservation = obs_potentials[0].DimShuffle(0, 'x');
            var seed = firstObservation * T.OnesLike(chain_potentials[0]);

            // The remaining time steps (index 1 onwards) drive the scan.
            var remainingObservations = obs_potentials[XSlicer.From(1)];
            var steps = T.Scan(
                fn: step,
                outputsInfo: seed,
                sequences: new[] { remainingObservations });

            // Only the last scan output is reduced over axis 0 and negated.
            var last = steps[-1];
            if (viterbi)
            {
                return -T.Max(last, axis: 0);
            }

            return -LogSumExp(last, axis: 0);
        }
Esempio n. 2
0
        /// <summary>
        /// Checks that the symbolic shape of a sliced shared tensor matches the expected extents
        /// for forward, reversed and strided ranges, and for a two-axis slice.
        /// </summary>
        public void TestShapeOfSlice()
        {
            // Shared variable built from 13 random values.
            var vectorValues = NN.Random.Uniform(-1.0f, 1.0f, 13).As<float>();
            var vector = T.Shared(0.2f * vectorValues, "v");

            // 5..8 forward -> 3 elements.
            var forwardSlice = vector[XSlicer.Range(5, 8)];
            AssertArray.WriteTheSame(new[] { 3 }, forwardSlice.Shape);

            // 8..5 with step -1 -> 3 elements.
            var reversedSlice = vector[XSlicer.Range(8, 5, -1)];
            AssertArray.WriteTheSame(new[] { 3 }, reversedSlice.Shape);

            // 3..11 with step 2 -> 4 elements.
            var stridedSlice = vector[XSlicer.Range(3, 11, 2)];
            AssertArray.WriteTheSame(new[] { 4 }, stridedSlice.Shape);

            // Shared variable built from an 8 x 22 block of random values.
            var matrixValues = NN.Random.Uniform(-1.0f, 1.0f, 8, 22).As<float>();
            var matrix = T.Shared(0.2f * matrixValues, "M");

            // First axis 6..1 with step -1 -> 5; second axis from 5 to the end -> 17.
            var blockSlice = matrix[XSlicer.Range(6, 1, -1), XSlicer.From(5)];
            AssertArray.WriteTheSame(new[] { 5, 17 }, blockSlice.Shape);
        }