public virtual void TestGetSummaryForInstance(GraphicalModel[] dataset, ConcatVector weights)
        {
            LogLikelihoodDifferentiableFunction fn = new LogLikelihoodDifferentiableFunction();

            foreach (GraphicalModel model in dataset)
            {
                double       goldLogLikelihood = LogLikelihood(model, weights);
                ConcatVector goldGradient      = DefinitionOfDerivative(model, weights);
                ConcatVector gradient          = new ConcatVector(0);
                double       logLikelihood     = fn.GetSummaryForInstance(model, weights, gradient);
                NUnit.Framework.Assert.AreEqual(goldLogLikelihood, logLikelihood, Math.Max(1.0e-3, Math.Abs(goldLogLikelihood) * 1.0e-2));
                // Our check for gradient similarity involves distance between endpoints of vectors, instead of elementwise
                // similarity, b/c it can be controlled as a percentage
                ConcatVector difference = goldGradient.DeepClone();
                difference.AddVectorInPlace(gradient, -1);
                double distance = Math.Sqrt(difference.DotProduct(difference));
                // The tolerance here is pretty large, since the gold gradient is computed approximately
                // 5% still tells us whether everything is working or not though
                if (distance > 5.0e-2)
                {
                    System.Console.Error.WriteLine("Definitional and calculated gradient differ!");
                    System.Console.Error.WriteLine("Definition approx: " + goldGradient);
                    System.Console.Error.WriteLine("Calculated: " + gradient);
                }
                NUnit.Framework.Assert.AreEqual(0.0, distance, 5.0e-2);
            }
        }
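
The distance check in this test compares gradient endpoints rather than individual entries, so the tolerance acts on one aggregate number. The same idea on plain arrays, as a standalone helper (the name is illustrative, not part of the library under test):

        static double EuclideanDistance(double[] a, double[] b)
        {
            // Accumulate squared elementwise differences, then take the root:
            // this bounds the total gradient error as a single number, which is
            // easier to control as a percentage than per-entry comparisons
            double sumSquared = 0.0;
            for (int i = 0; i < a.Length; i++)
            {
                double d = a[i] - b[i];
                sumSquared += d * d;
            }
            return Math.Sqrt(sumSquared);
        }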
        /// <summary>Construct a TableFactor for inference within a model.</summary>
        /// <remarks>
        /// Construct a TableFactor for inference within a model. This just copies the important bits from the model factor,
        /// and replaces the ConcatVectorTable with an internal data structure that has already done all the dot products
        /// with the weights, and so stores only doubles.
        /// <p>
        /// Each element of the table is given by: t_i = exp(f_i*w)
        /// </remarks>
        /// <param name="weights">the vector to dot product with every element of the factor table</param>
        /// <param name="factor">the feature factor to be multiplied in</param>
        public TableFactor(ConcatVector weights, GraphicalModel.Factor factor)
            : base(factor.featuresTable.GetDimensions())
        {
            this.neighborIndices = factor.neigborIndices;
            // Calculate the factor residents by dot product with the weights
            // OPTIMIZATION:
            // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
            // we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
            IEnumerator <int[]> fastPassByReferenceIterator = factor.featuresTable.FastPassByReferenceIterator();

            // The iterator mutates and re-returns the same assignment[] array on each MoveNext(),
            // rather than allocating a new one
            while (fastPassByReferenceIterator.MoveNext())
            {
                int[] assignment = fastPassByReferenceIterator.Current;
                SetAssignmentLogValue(assignment, factor.featuresTable.GetAssignmentValue(assignment).Get().DotProduct(weights));
            }
        }
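
A rough usage sketch of this constructor, assuming a GraphicalModel and a weights vector are already in hand; it materializes each factor's table and reads an entry back in probability space:

        foreach (GraphicalModel.Factor factor in model.factors)
        {
            TableFactor table = new TableFactor(weights, factor);
            foreach (int[] assn in table)
            {
                // Each entry is exp(f_i * w), per the remark above
                double entry = table.GetAssignmentValue(assn);
            }
        }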
        /// <summary>
        /// Slowest possible way to calculate a derivative for a model: exhaustive definitional calculation, using the super
        /// slow logLikelihood function from this test suite.
        /// </summary>
        /// <param name="model">the model the get the derivative for</param>
        /// <param name="weights">the weights to get the derivative at</param>
        /// <returns>the derivative of the log likelihood with respect to the weights</returns>
        private ConcatVector DefinitionOfDerivative(GraphicalModel model, ConcatVector weights)
        {
            double       epsilon      = 1.0e-7;
            ConcatVector goldGradient = new ConcatVector(ConcatVecComponents);

            for (int i = 0; i < ConcatVecComponents; i++)
            {
                double[] component = new double[ConcatVecComponentLength];
                for (int j = 0; j < ConcatVecComponentLength; j++)
                {
                    // Create a unit vector pointing in the direction of this element of this component
                    ConcatVector unitVectorIJ = new ConcatVector(ConcatVecComponents);
                    unitVectorIJ.SetSparseComponent(i, j, 1.0);
                    // Create a +eps weight vector
                    ConcatVector weightsPlusEpsilon = weights.DeepClone();
                    weightsPlusEpsilon.AddVectorInPlace(unitVectorIJ, epsilon);
                    // Create a -eps weight vector
                    ConcatVector weightsMinusEpsilon = weights.DeepClone();
                    weightsMinusEpsilon.AddVectorInPlace(unitVectorIJ, -epsilon);
                    // Use the definition (f(x+eps) - f(x-eps))/(2*eps)
                    component[j] = (LogLikelihood(model, weightsPlusEpsilon) - LogLikelihood(model, weightsMinusEpsilon)) / (2 * epsilon);
                    // If we encounter an impossible assignment, logLikelihood will return negative infinity, which will
                    // screw with the definitional calculation
                    if (double.IsNaN(component[j]))
                    {
                        component[j] = 0.0;
                    }
                }
                goldGradient.SetDenseComponent(i, component);
            }
            return(goldGradient);
        }
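
For reference, the same central-difference rule on a plain scalar function, standalone and checkable by hand (f and the evaluation point are arbitrary):

            // Central difference: f'(x) ≈ (f(x + eps) - f(x - eps)) / (2 * eps)
            double eps = 1.0e-7;
            Func<double, double> f = x => x * x * x;
            double x0 = 2.0;
            double approx = (f(x0 + eps) - f(x0 - eps)) / (2 * eps);
            // approx ≈ 12.0, the analytic derivative 3 * x0^2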
 internal static ConcatVector[] MakeVectors(ConcatVectorBenchmark.ConcatVectorConstructionRecord[] records)
 {
     ConcatVector[] vectors = new ConcatVector[records.Length];
     for (int i = 0; i < records.Length; i++)
     {
         vectors[i] = records[i].Create();
     }
     return(vectors);
 }
        // this magic number was arrived at with relation to the CoNLL benchmark, and tinkering
        // (assumption: the constant this comment refers to was dropped from the excerpt; 0.1 is the
        // learning rate assumed by the AdaGrad step below)
        private const double Alpha = 0.1;
        public override bool UpdateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, AbstractBatchOptimizer.OptimizationState optimizationState, bool quiet)
        {
            BacktrackingAdaGradOptimizer.AdaGradOptimizationState s = (BacktrackingAdaGradOptimizer.AdaGradOptimizationState)optimizationState;
            double logLikelihoodChange = logLikelihood - s.lastLogLikelihood;

            if (logLikelihoodChange == 0)
            {
                if (!quiet)
                {
                    log.Info("\tlogLikelihood improvement = 0: quitting");
                }
                return(true);
            }
            else
            {
                // Check if we should backtrack
                if (logLikelihoodChange < 0)
                {
                    // If we should, move the weights back by half, and cut the lastDerivative by half
                    s.lastDerivative.MapInPlace(d => d / 2.0); // reconstructed lambda (the converter emitted null): cut the last step in half, per the comment above
                    weights.AddVectorInPlace(s.lastDerivative, -1.0);
                    if (!quiet)
                    {
                        log.Info("\tBACKTRACK...");
                    }
                    // if the lastDerivative norm falls below a threshold, it means we've converged
                    if (s.lastDerivative.DotProduct(s.lastDerivative) < 1.0e-10)
                    {
                        if (!quiet)
                        {
                            log.Info("\tBacktracking derivative norm " + s.lastDerivative.DotProduct(s.lastDerivative) + " < 1.0e-9: quitting");
                        }
                        return(true);
                    }
                }
                else
                {
                    // Apply AdaGrad. The converter emitted null for the two lambdas below; these are
                    // reconstructions of the standard AdaGrad rule, labeled as assumptions
                    ConcatVector squared = gradient.DeepClone();
                    squared.MapInPlace(d => d * d); // accumulate squared gradients
                    s.adagradAccumulator.AddVectorInPlace(squared, 1.0);
                    ConcatVector sqrt = s.adagradAccumulator.DeepClone();
                    sqrt.MapInPlace(d => Alpha / (1.0e-9 + Math.Sqrt(d))); // assumed scaling: learning rate over damped root of the accumulator
                    gradient.ElementwiseProductInPlace(sqrt);
                    weights.AddVectorInPlace(gradient, 1.0);
                    // Setup for backtracking, in case necessary
                    s.lastDerivative    = gradient;
                    s.lastLogLikelihood = logLikelihood;
                    if (!quiet)
                    {
                        log.Info("\tLL: " + logLikelihood);
                    }
                }
            }
            return(false);
        }
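
In scalar form, the non-backtracking branch above is the standard AdaGrad ascent rule, w ← w + α·g / (ε + √(Σg²)). A minimal sketch, with the caveat that α and the placement of the damping term ε are assumptions rather than facts read off this excerpt:

            // Scalar AdaGrad sketch (illustrative values, not the optimizer's actual fields)
            double accumulator = 0.0;   // running sum of squared gradients
            double alpha = 0.1;         // assumed learning rate
            double epsilon = 1.0e-9;    // assumed damping term to avoid division by zero

            double Step(double gradient)
            {
                accumulator += gradient * gradient;
                return alpha * gradient / (epsilon + Math.Sqrt(accumulator));
            }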
        /// <summary>The slowest, but obviously correct way to get log likelihood.</summary>
        /// <remarks>
        /// The slowest, but obviously correct way to get log likelihood. We've already tested the partition function in
        /// the CliqueTreeTest, but in the interest of making things as different as possible to catch any lurking bugs or
        /// numerical issues, we use the brute force approach here.
        /// </remarks>
        /// <param name="model">the model to get the log-likelihood of, assumes labels for assignments</param>
        /// <param name="weights">the weights to get the log-likelihood at</param>
        /// <returns>the log-likelihood</returns>
        private double LogLikelihood(GraphicalModel model, ConcatVector weights)
        {
            ICollection <TableFactor> tableFactors = model.factors.Stream().Map(factor => new TableFactor(weights, factor)).Collect(Collectors.ToSet()); // reconstructed lambda (the converter emitted null)

            System.Diagnostics.Debug.Assert((tableFactors.Count == model.factors.Count));
            // this is the super slow but obviously correct way to get global marginals
            TableFactor bruteForce = null;

            foreach (TableFactor factor in tableFactors)
            {
                if (bruteForce == null)
                {
                    bruteForce = factor;
                }
                else
                {
                    bruteForce = bruteForce.Multiply(factor);
                }
            }
            System.Diagnostics.Debug.Assert((bruteForce != null));
            // observe out all variables that have been registered
            TableFactor observed = bruteForce;

            foreach (int n in bruteForce.neighborIndices)
            {
                if (model.GetVariableMetaDataByReference(n).Contains(CliqueTree.VariableObservedValue))
                {
                    int value = System.Convert.ToInt32(model.GetVariableMetaDataByReference(n)[CliqueTree.VariableObservedValue]);
                    if (observed.neighborIndices.Length > 1)
                    {
                        observed = observed.Observe(n, value);
                    }
                    else
                    {
                        // If we've observed everything, then just quit
                        return(0.0);
                    }
                }
            }
            bruteForce = observed;
            // Now we can get a partition function
            double partitionFunction = bruteForce.ValueSum();

            // For now, we'll assume that all the variables are given for training. EM is another problem altogether
            int[] assignment = new int[bruteForce.neighborIndices.Length];
            for (int i = 0; i < assignment.Length; i++)
            {
                System.Diagnostics.Debug.Assert((!model.GetVariableMetaDataByReference(bruteForce.neighborIndices[i]).Contains(CliqueTree.VariableObservedValue)));
                assignment[i] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(bruteForce.neighborIndices[i])[LogLikelihoodDifferentiableFunction.VariableTrainingValue]);
            }
            if (bruteForce.GetAssignmentValue(assignment) == 0 || partitionFunction == 0)
            {
                return(double.NegativeInfinity);
            }
            return(Math.Log(bruteForce.GetAssignmentValue(assignment)) - Math.Log(partitionFunction));
        }
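
The quantity returned above is log P(y) = log t(y) − log Z, where t is the unnormalized table value for the labeled assignment and Z the partition function. A toy numeric check (values invented for illustration):

            // Unnormalized scores for a two-state variable: t = [exp(1.0), exp(2.0)]
            double[] t = { Math.Exp(1.0), Math.Exp(2.0) };
            double partition = t[0] + t[1];                     // Z = sum over all assignments
            double logLikelihoodOfState1 = Math.Log(t[1]) - Math.Log(partition);
            // equals 2.0 - log(e^1 + e^2), i.e. the log of a softmax probability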
        internal static long CloneBenchmark(ConcatVector vector)
        {
            long before = Runtime.CurrentTimeMillis();

            for (int i = 0; i < 10000000; i++)
            {
                vector.DeepClone();
            }
            return(Runtime.CurrentTimeMillis() - before);
        }
        /*
         * @Theory
         * public void testOptimizeLogLikelihoodWithConstraints(AbstractBatchOptimizer optimizer,
         *     @ForAll(sampleSize = 5) @From(LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class) GraphicalModel[] dataset,
         *     @ForAll(sampleSize = 2) @From(LogLikelihoodFunctionTest.WeightsGenerator.class) ConcatVector initialWeights,
         *     @ForAll(sampleSize = 2) @InRange(minDouble = 0.0, maxDouble = 5.0) double l2regularization) throws Exception {
         *   Random r = new Random(42);
         *
         *   // Put in some constraints
         *   int constraintComponent = r.nextInt(initialWeights.getNumberOfComponents());
         *   double constraintValue = r.nextDouble();
         *
         *   if (r.nextBoolean()) {
         *     optimizer.addSparseConstraint(constraintComponent, 0, constraintValue);
         *   } else {
         *     optimizer.addDenseConstraint(constraintComponent, new double[]{constraintValue});
         *   }
         *
         *   AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
         *   ConcatVector finalWeights = optimizer.optimize(dataset, ll, initialWeights, l2regularization, 1.0e-9, false);
         *   System.err.println("Finished optimizing");
         *
         *   assertEquals(constraintValue, finalWeights.getValueAt(constraintComponent, 0), 1.0e-9);
         *
         *   double logLikelihood = getValueSum(dataset, finalWeights, ll, l2regularization);
         *
         *   // Check in a whole bunch of random directions really nearby that there is no nearby point with a higher log
         *   // likelihood
         *   for (int i = 0; i < 1000; i++) {
         *     int size = finalWeights.getNumberOfComponents();
         *     ConcatVector randomDirection = new ConcatVector(size);
         *     for (int j = 0; j < size; j++) {
         *       if (j == constraintComponent) continue;
         *       double[] dense = new double[finalWeights.isComponentSparse(j) ? finalWeights.getSparseIndex(j) + 1 : finalWeights.getDenseComponent(j).length];
         *       for (int k = 0; k < dense.length; k++) {
         *         dense[k] = (r.nextDouble() - 0.5) * 1.0e-3;
         *       }
         *       randomDirection.setDenseComponent(j, dense);
         *     }
         *
         *     ConcatVector randomPerturbation = finalWeights.deepClone();
         *     randomPerturbation.addVectorInPlace(randomDirection, 1.0);
         *
         *     double randomPerturbedLogLikelihood = getValueSum(dataset, randomPerturbation, ll, l2regularization);
         *
         *     // Check that we're within a very small margin of error (around 3 decimal places) of the randomly
         *     // discovered value
         *     if (logLikelihood < randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood)))) {
         *       System.err.println("Thought optimal point was: " + logLikelihood);
         *       System.err.println("Discovered better point: " + randomPerturbedLogLikelihood);
         *     }
         *
         *     assertTrue(logLikelihood >= randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood))));
         *   }
         * }
         */
        private double GetValueSum <T>(T[] dataset, ConcatVector weights, AbstractDifferentiableFunction <T> fn, double l2regularization)
        {
            double value = 0.0;

            foreach (T t in dataset)
            {
                value += fn.GetSummaryForInstance(t, weights, new ConcatVector(0));
            }
            return((value / dataset.Length) - (weights.DotProduct(weights) * l2regularization));
        }
 public virtual void ApplyToDerivative(ConcatVector derivative)
 {
     if (isSparse)
     {
         derivative.SetSparseComponent(component, index, 0.0);
     }
     else
     {
         derivative.SetDenseComponent(component, new double[] { 0.0 });
     }
 }
 public GradientWorker(AbstractBatchOptimizer.TrainingWorker <T> mainWorker, int threadIdx, int numThreads, IList <T> queue, AbstractDifferentiableFunction <T> fn, ConcatVector weights)
 {
     // This is to help the dynamic re-balancing of work queues
     this.mainWorker = mainWorker;
     this.threadIdx  = threadIdx;
     this.numThreads = numThreads;
     this.queue      = queue;
     this.fn         = fn;
     this.weights    = weights;
     localDerivative = weights.NewEmptyClone();
 }
 public virtual void ApplyToWeights(ConcatVector weights)
 {
     if (isSparse)
     {
         weights.SetSparseComponent(component, index, value);
     }
     else
     {
         weights.SetDenseComponent(component, arr);
     }
 }
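
Taken together with ApplyToDerivative above, this pair implements a pinned-weight constraint: the weight is forced to its constrained value, and the matching derivative entry is zeroed so the optimizer never moves it. A minimal sketch of the intended call order, where `Constraint` stands in for the enclosing class (its name is not visible in this excerpt):

 // Hypothetical driver: apply all constraints once per optimization step
 void ApplyConstraints(ConcatVector weights, ConcatVector derivative, IList <Constraint> constraints)
 {
     foreach (Constraint c in constraints)
     {
         c.ApplyToWeights(weights);       // force the constrained component to its fixed value
         c.ApplyToDerivative(derivative); // zero its gradient entry so it stays pinned
     }
 }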
 public TrainingWorker(AbstractBatchOptimizer _enclosing, T[] dataset, AbstractDifferentiableFunction <T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
 {
     this._enclosing        = _enclosing;
     this.optimizationState = this._enclosing.GetFreshOptimizationState(initialWeights);
     this.weights           = initialWeights.DeepClone();
     this.dataset           = dataset;
     this.fn = fn;
     this.l2regularization          = l2regularization;
     this.convergenceDerivativeNorm = convergenceDerivativeNorm;
     this.quiet = quiet;
 }
        internal static long ConstructionBenchmark(ConcatVectorBenchmark.ConcatVectorConstructionRecord[] records)
        {
            // Then run the ConcatVector parts
            long before = Runtime.CurrentTimeMillis();

            for (int i = 0; i < records.Length; i++)
            {
                ConcatVector v = records[i].Create();
            }
            // Report the union
            return(Runtime.CurrentTimeMillis() - before);
        }
        public virtual void TestCalculateMarginals(GraphicalModel model, ConcatVector weights)
        {
            CliqueTree inference = new CliqueTree(model, weights);

            // This is the basic check that inference works when you first construct the model
            CheckMarginalsAgainstBruteForce(model, weights, inference);
            // Now we go through several random mutations to the model, and check that everything is still consistent
            Random r = new Random();

            for (int i = 0; i < 10; i++)
            {
                RandomlyMutateGraphicalModel(model, r);
                CheckMarginalsAgainstBruteForce(model, weights, inference);
            }
        }
            // Creates the multivector
            public virtual ConcatVector Create()
            {
                ConcatVector mv = new ConcatVector(componentSizes.Length);

                for (int i = 0; i < componentSizes.Length; i++)
                {
                    if (componentSizes[i] == -1)
                    {
                        mv.SetSparseComponent(i, sparseOffsets[i], sparseValues[i]);
                    }
                    else
                    {
                        mv.SetDenseComponent(i, densePieces[i]);
                    }
                }
                return(mv);
            }
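
Create follows the record's convention that componentSizes[i] == -1 marks a sparse component (a single index/value pair) and any other size marks a dense array. Building the same mixed layout by hand, using only the ConcatVector setters that appear elsewhere in this listing:

            ConcatVector v = new ConcatVector(2);                // two components
            v.SetSparseComponent(0, 7, 0.5);                     // component 0: sparse, index 7 -> 0.5
            v.SetDenseComponent(1, new double[] { 0.1, 0.2 });   // component 1: dense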
        /// <summary>Construct a TableFactor for inference within a model.</summary>
        /// <remarks>
        /// Construct a TableFactor for inference within a model. This is the same as the other constructor, except that the
        /// table is observed out before any unnecessary dot products are done, so hopefully we dramatically reduce the
        /// number of computations required to calculate the resulting table.
        /// <p>
        /// Each element of the table is given by: t_i = exp(f_i*w)
        /// </remarks>
        /// <param name="weights">the vector to dot product with every element of the factor table</param>
        /// <param name="factor">the feature factor to be multiplied in</param>
        public TableFactor(ConcatVector weights, GraphicalModel.Factor factor, int[] observations)
            : base()
        {
            System.Diagnostics.Debug.Assert((observations.Length == factor.neigborIndices.Length));
            int size = 0;

            foreach (int observation in observations)
            {
                if (observation == -1)
                {
                    size++;
                }
            }
            neighborIndices = new int[size];
            dimensions      = new int[size];
            int[] forwardPointers  = new int[size];
            int[] factorAssignment = new int[factor.neigborIndices.Length];
            int   cursor           = 0;

            for (int i = 0; i < factor.neigborIndices.Length; i++)
            {
                if (observations[i] == -1)
                {
                    neighborIndices[cursor] = factor.neigborIndices[i];
                    dimensions[cursor]      = factor.featuresTable.GetDimensions()[i];
                    forwardPointers[cursor] = i;
                    cursor++;
                }
                else
                {
                    factorAssignment[i] = observations[i];
                }
            }
            System.Diagnostics.Debug.Assert((cursor == size));
            values = new double[CombinatorialNeighborStatesCount()];
            foreach (int[] assn in this)
            {
                for (int i_1 = 0; i_1 < assn.Length; i_1++)
                {
                    factorAssignment[forwardPointers[i_1]] = assn[i_1];
                }
                SetAssignmentLogValue(assn, factor.featuresTable.GetAssignmentValue(factorAssignment).Get().DotProduct(weights));
            }
        }
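
The observations array here is indexed parallel to factor.neigborIndices: -1 leaves a variable free, and any other value observes it out before the dot products are taken. A shape-only sketch, assuming a weights vector and a three-variable factor from elsewhere:

            // Observe the middle variable to state 1, leave the other two free
            int[] observations = { -1, 1, -1 };
            TableFactor collapsed = new TableFactor(weights, factor, observations);
            // collapsed now ranges over the two unobserved variables only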
            public override TableFactorTest.PartiallyObservedConstructorData Generate(SourceOfRandomness sourceOfRandomness, IGenerationStatus generationStatus)
            {
                int len = sourceOfRandomness.NextInt(1, 4);
                ICollection <int> taken = new HashSet <int>();

                int[] neighborIndices = new int[len];
                int[] dimensions      = new int[len];
                int[] observations    = new int[len];
                int   numObserved     = 0;

                for (int i = 0; i < len; i++)
                {
                    int j = sourceOfRandomness.NextInt(8);
                    while (taken.Contains(j))
                    {
                        j = sourceOfRandomness.NextInt(8);
                    }
                    taken.Add(j);
                    neighborIndices[i] = j;
                    dimensions[i]      = sourceOfRandomness.NextInt(1, 3);
                    if (sourceOfRandomness.NextBoolean() && numObserved + 1 < dimensions.Length)
                    {
                        observations[i] = sourceOfRandomness.NextInt(dimensions[i]);
                        numObserved++;
                    }
                    else
                    {
                        observations[i] = -1;
                    }
                }
                ConcatVectorTable t = new ConcatVectorTable(dimensions);

                TableFactorTest.ConcatVectorGenerator gen = new TableFactorTest.ConcatVectorGenerator(typeof(ConcatVector));
                foreach (int[] assn in t)
                {
                    ConcatVector vec = gen.Generate(sourceOfRandomness, generationStatus);
                    t.SetAssignmentValue(assn, () => vec); // reconstructed supplier (the converter emitted null): store the generated vector
                }
                TableFactorTest.PartiallyObservedConstructorData data = new TableFactorTest.PartiallyObservedConstructorData();
                data.factor       = new GraphicalModel.Factor(t, neighborIndices);
                data.observations = observations;
                return(data);
            }
        /// <exception cref="System.IO.IOException"/>
        /// <exception cref="System.TypeLoadException"/>
        internal static ConcatVectorBenchmark.SerializationReport ProtoSerializationBenchmark(ConcatVectorBenchmark.ConcatVectorConstructionRecord[] records)
        {
            ConcatVector[]        vectors = MakeVectors(records);
            ByteArrayOutputStream bArr    = new ByteArrayOutputStream();
            long before = Runtime.CurrentTimeMillis();

            for (int i = 0; i < vectors.Length; i++)
            {
                vectors[i].WriteToStream(bArr);
            }
            bArr.Close();
            byte[] bytes = bArr.ToByteArray();
            ByteArrayInputStream bArrIn = new ByteArrayInputStream(bytes);

            for (int i_1 = 0; i_1 < vectors.Length; i_1++)
            {
                ConcatVector.ReadFromStream(bArrIn);
            }
            ConcatVectorBenchmark.SerializationReport sr = new ConcatVectorBenchmark.SerializationReport();
            sr.time = Runtime.CurrentTimeMillis() - before;
            sr.size = bytes.Length;
            return(sr);
        }
        public virtual void TestOptimizeLogLikelihood(AbstractBatchOptimizer optimizer, GraphicalModel[] dataset, ConcatVector initialWeights, double l2regularization)
        {
            AbstractDifferentiableFunction <GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
            ConcatVector finalWeights = optimizer.Optimize(dataset, ll, initialWeights, l2regularization, 1.0e-9, true);

            System.Console.Error.WriteLine("Finished optimizing");
            double logLikelihood = GetValueSum(dataset, finalWeights, ll, l2regularization);
            // Check in a whole bunch of random directions really nearby that there is no nearby point with a higher log
            // likelihood
            Random r = new Random(42);

            for (int i = 0; i < 1000; i++)
            {
                int          size            = finalWeights.GetNumberOfComponents();
                ConcatVector randomDirection = new ConcatVector(size);
                for (int j = 0; j < size; j++)
                {
                    double[] dense = new double[finalWeights.IsComponentSparse(j) ? finalWeights.GetSparseIndex(j) + 1 : finalWeights.GetDenseComponent(j).Length];
                    for (int k = 0; k < dense.Length; k++)
                    {
                        dense[k] = (r.NextDouble() - 0.5) * 1.0e-3;
                    }
                    randomDirection.SetDenseComponent(j, dense);
                }
                ConcatVector randomPerturbation = finalWeights.DeepClone();
                randomPerturbation.AddVectorInPlace(randomDirection, 1.0);
                double randomPerturbedLogLikelihood = GetValueSum(dataset, randomPerturbation, ll, l2regularization);
                // Check that we're within a very small margin of error (around 3 decimal places) of the randomly
                // discovered value
                if (logLikelihood < randomPerturbedLogLikelihood - (1.0e-3 * Math.Max(1.0, Math.Abs(logLikelihood))))
                {
                    System.Console.Error.WriteLine("Thought optimal point was: " + logLikelihood);
                    System.Console.Error.WriteLine("Discovered better point: " + randomPerturbedLogLikelihood);
                }
                NUnit.Framework.Assert.IsTrue(logLikelihood >= randomPerturbedLogLikelihood - (1.0e-3 * Math.Max(1.0, Math.Abs(logLikelihood))));
            }
        }
 private void RandomlyMutateGraphicalModel(GraphicalModel model, Random r)
 {
     if (r.NextBoolean() && model.factors.Count > 1)
     {
         // Remove one factor at random
         model.factors.Remove(Sharpen.Collections.ToArray(model.factors, new GraphicalModel.Factor[model.factors.Count])[r.NextInt(model.factors.Count)]);
     }
     else
     {
         // Add a simple binary factor, attaching a variable we haven't touched yet, but do observe, to an
         // existing variable. This represents the human observation operation in LENSE
         int maxVar        = 0;
         int attachVar     = -1;
         int attachVarSize = 0;
         foreach (GraphicalModel.Factor f in model.factors)
         {
             for (int j = 0; j < f.neigborIndices.Length; j++)
             {
                 int k = f.neigborIndices[j];
                 if (k > maxVar)
                 {
                     maxVar = k;
                 }
                 if (r.NextDouble() > 0.3 || attachVar == -1)
                 {
                     attachVar     = k;
                     attachVarSize = f.featuresTable.GetDimensions()[j];
                 }
             }
         }
         int newVar     = maxVar + 1;
         int newVarSize = 1 + r.NextInt(2);
         if (maxVar >= 8)
         {
             bool[] seenVariables = new bool[maxVar + 1];
             foreach (GraphicalModel.Factor f_1 in model.factors)
             {
                 foreach (int n in f_1.neigborIndices)
                 {
                     seenVariables[n] = true;
                 }
             }
             for (int j = 0; j < seenVariables.Length; j++)
             {
                 if (!seenVariables[j])
                 {
                     newVar = j;
                     break;
                 }
             }
             // This means the model is already too gigantic to be tractable, so we don't add anything here
             if (newVar == maxVar + 1)
             {
                 return;
             }
         }
         if (model.GetVariableMetaDataByReference(newVar).Contains(CliqueTree.VariableObservedValue))
         {
             int assignment = System.Convert.ToInt32(model.GetVariableMetaDataByReference(newVar)[CliqueTree.VariableObservedValue]);
             if (assignment >= newVarSize)
             {
                 newVarSize = assignment + 1;
             }
         }
          // The converter emitted null for the feature-vector thunk below; this inline lambda is an
          // assumed reconstruction producing a simple random vector per assignment
          GraphicalModel.Factor binary = model.AddFactor(new int[] { newVar, attachVar }, new int[] { newVarSize, attachVarSize }, assignment =>
          {
              ConcatVector v = new ConcatVector(1);
              v.SetSparseComponent(0, 0, r.NextDouble());
              return v;
          });
         // "Cook" the randomly generated feature vector thunks, so they don't change as we run the system
         foreach (int[] assignment_1 in binary.featuresTable)
         {
             ConcatVector randomlyGenerated = binary.featuresTable.GetAssignmentValue(assignment_1).Get();
              binary.featuresTable.SetAssignmentValue(assignment_1, () => randomlyGenerated); // "cook" the thunk into a constant supplier
         }
     }
 }
        public virtual void CheckMAPAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference)
        {
            int[] map = inference.CalculateMAP();
            ICollection <TableFactor> tableFactors = model.factors.Stream().Map(factor => new TableFactor(weights, factor)).Collect(Collectors.ToSet()); // reconstructed lambda (the converter emitted null)
            // this is the super slow but obviously correct way to get global marginals
            TableFactor bruteForce = null;

            foreach (TableFactor factor in tableFactors)
            {
                if (bruteForce == null)
                {
                    bruteForce = factor;
                }
                else
                {
                    bruteForce = bruteForce.Multiply(factor);
                }
            }
            System.Diagnostics.Debug.Assert((bruteForce != null));
            // observe out all variables that have been registered
            TableFactor observed = bruteForce;

            foreach (int n in bruteForce.neighborIndices)
            {
                if (model.GetVariableMetaDataByReference(n).Contains(CliqueTree.VariableObservedValue))
                {
                    int value = System.Convert.ToInt32(model.GetVariableMetaDataByReference(n)[CliqueTree.VariableObservedValue]);
                    if (observed.neighborIndices.Length > 1)
                    {
                        observed = observed.Observe(n, value);
                    }
                    else
                    {
                        // If we've observed everything, then just quit
                        return;
                    }
                }
            }
            bruteForce = observed;
            int largestVariableNum = 0;

            foreach (GraphicalModel.Factor f in model.factors)
            {
                foreach (int i in f.neigborIndices)
                {
                    if (i > largestVariableNum)
                    {
                        largestVariableNum = i;
                    }
                }
            }
            // this is presented in true order, where 0 corresponds to var 0
            int[] mapValueAssignment = new int[largestVariableNum + 1];
            // this is kept in the order that the factor presents to us
            int[] highestValueAssignment = new int[bruteForce.neighborIndices.Length];
            foreach (int[] assignment in bruteForce)
            {
                if (bruteForce.GetAssignmentValue(assignment) > bruteForce.GetAssignmentValue(highestValueAssignment))
                {
                    highestValueAssignment = assignment;
                    for (int i = 0; i < assignment.Length; i++)
                    {
                        mapValueAssignment[bruteForce.neighborIndices[i]] = assignment[i];
                    }
                }
            }
            int[] forcedAssignments = new int[largestVariableNum + 1];
            for (int i_1 = 0; i_1 < mapValueAssignment.Length; i_1++)
            {
                if (model.GetVariableMetaDataByReference(i_1).Contains(CliqueTree.VariableObservedValue))
                {
                    mapValueAssignment[i_1] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(i_1)[CliqueTree.VariableObservedValue]);
                    forcedAssignments[i_1]  = mapValueAssignment[i_1];
                }
            }
            if (!Arrays.Equals(mapValueAssignment, map))
            {
                System.Console.Error.WriteLine("---");
                System.Console.Error.WriteLine("Relevant variables: " + Arrays.ToString(bruteForce.neighborIndices));
                System.Console.Error.WriteLine("Var Sizes: " + Arrays.ToString(bruteForce.GetDimensions()));
                System.Console.Error.WriteLine("MAP: " + Arrays.ToString(map));
                System.Console.Error.WriteLine("Brute force map: " + Arrays.ToString(mapValueAssignment));
                System.Console.Error.WriteLine("Forced assignments: " + Arrays.ToString(forcedAssignments));
            }
            foreach (int i_2 in bruteForce.neighborIndices)
            {
                // Only check defined variables
                NUnit.Framework.Assert.AreEqual(mapValueAssignment[i_2], map[i_2]);
            }
        }
        private void CheckMarginalsAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference)
        {
            CliqueTree.MarginalResult result       = inference.CalculateMarginals();
            double[][] marginals                   = result.marginals;
            ICollection <TableFactor> tableFactors = model.factors.Stream().Map(factor => new TableFactor(weights, factor)).Collect(Collectors.ToSet()); // reconstructed lambda (the converter emitted null)

            System.Diagnostics.Debug.Assert((tableFactors.Count == model.factors.Count));
            // this is the super slow but obviously correct way to get global marginals
            TableFactor bruteForce = null;

            foreach (TableFactor factor in tableFactors)
            {
                if (bruteForce == null)
                {
                    bruteForce = factor;
                }
                else
                {
                    bruteForce = bruteForce.Multiply(factor);
                }
            }
            if (bruteForce != null)
            {
                // observe out all variables that have been registered
                TableFactor observed = bruteForce;
                for (int i = 0; i < bruteForce.neighborIndices.Length; i++)
                {
                    int n = bruteForce.neighborIndices[i];
                    if (model.GetVariableMetaDataByReference(n).Contains(CliqueTree.VariableObservedValue))
                    {
                        int value = System.Convert.ToInt32(model.GetVariableMetaDataByReference(n)[CliqueTree.VariableObservedValue]);
                        // Check that the marginals reflect the observation
                        for (int j = 0; j < marginals[n].Length; j++)
                        {
                            NUnit.Framework.Assert.AreEqual(j == value ? 1.0 : 0.0, marginals[n][j], 1.0e-9);
                        }
                        if (observed.neighborIndices.Length > 1)
                        {
                            observed = observed.Observe(n, value);
                        }
                        else
                        {
                            // If we've observed everything, then just quit
                            return;
                        }
                    }
                }
                bruteForce = observed;
                // Spot check each of the marginals in the brute force calculation
                double[][] bruteMarginals = bruteForce.GetSummedMarginals();
                int        index          = 0;
                foreach (int i_1 in bruteForce.neighborIndices)
                {
                    bool     isEqual = true;
                    double[] brute   = bruteMarginals[index];
                    index++;
                    System.Diagnostics.Debug.Assert((brute != null));
                    System.Diagnostics.Debug.Assert((marginals[i_1] != null));
                    for (int j = 0; j < brute.Length; j++)
                    {
                        if (double.IsNaN(brute[j]))
                        {
                            isEqual = false;
                            break;
                        }
                        if (Math.Abs(brute[j] - marginals[i_1][j]) > 3.0e-2)
                        {
                            isEqual = false;
                            break;
                        }
                    }
                    if (!isEqual)
                    {
                        System.Console.Error.WriteLine("Arrays not equal! Variable " + i_1);
                        System.Console.Error.WriteLine("\tGold: " + Arrays.ToString(brute));
                        System.Console.Error.WriteLine("\tResult: " + Arrays.ToString(marginals[i_1]));
                    }
                    Assert.AssertArrayEquals(brute, marginals[i_1], 3.0e-2);
                }
                // Spot check the partition function
                double goldPartitionFunction = bruteForce.ValueSum();
                // Correct to within 3%
                NUnit.Framework.Assert.AreEqual(goldPartitionFunction, result.partitionFunction, goldPartitionFunction * 3.0e-2);
                // Check the joint marginals
                foreach (GraphicalModel.Factor f in model.factors)
                {
                    NUnit.Framework.Assert.IsTrue(result.jointMarginals.Contains(f));
                    TableFactor bruteForceJointMarginal = bruteForce;
                    foreach (int n in bruteForce.neighborIndices)
                    {
                        foreach (int i_2 in f.neigborIndices)
                        {
                            if (i_2 == n)
                            {
                                goto outer_continue;
                            }
                        }
                        if (bruteForceJointMarginal.neighborIndices.Length > 1)
                        {
                            bruteForceJointMarginal = bruteForceJointMarginal.SumOut(n);
                        }
                        else
                        {
                            int[] fixedAssignment = new int[f.neigborIndices.Length];
                            for (int i_3 = 0; i_3 < fixedAssignment.Length; i_3++)
                            {
                                fixedAssignment[i_3] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(f.neigborIndices[i_3])[CliqueTree.VariableObservedValue]);
                            }
                            foreach (int[] assn in result.jointMarginals[f])
                            {
                                if (Arrays.Equals(assn, fixedAssignment))
                                {
                                    NUnit.Framework.Assert.AreEqual(1.0, result.jointMarginals[f].GetAssignmentValue(assn), 1.0e-7);
                                }
                                else
                                {
                                    if (result.jointMarginals[f].GetAssignmentValue(assn) != 0)
                                    {
                                        TableFactor j = result.jointMarginals[f];
                                        foreach (int[] assignment in j)
                                        {
                                            System.Console.Error.WriteLine(Arrays.ToString(assignment) + ": " + j.GetAssignmentValue(assignment));
                                        }
                                    }
                                    NUnit.Framework.Assert.AreEqual(0.0, result.jointMarginals[f].GetAssignmentValue(assn), 1.0e-7);
                                }
                            }
                            goto marginals_continue;
                        }
                        outer_continue :;
                    }
                    // Find the correspondence between the brute force joint marginal, which may be missing variables
                    // because they were observed out of the table, and the output joint marginals, which are always an exact
                    // match for the original factor
                    int[] backPointers  = new int[f.neigborIndices.Length];
                    int[] observedValue = new int[f.neigborIndices.Length];
                    for (int i_4 = 0; i_4 < backPointers.Length; i_4++)
                    {
                        if (model.GetVariableMetaDataByReference(f.neigborIndices[i_4]).Contains(CliqueTree.VariableObservedValue))
                        {
                            observedValue[i_4] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(f.neigborIndices[i_4])[CliqueTree.VariableObservedValue]);
                            backPointers[i_4]  = -1;
                        }
                        else
                        {
                            observedValue[i_4] = -1;
                            backPointers[i_4]  = -1;
                            for (int j = 0; j < bruteForceJointMarginal.neighborIndices.Length; j++)
                            {
                                if (bruteForceJointMarginal.neighborIndices[j] == f.neigborIndices[i_4])
                                {
                                    backPointers[i_4] = j;
                                }
                            }
                            System.Diagnostics.Debug.Assert((backPointers[i_4] != -1));
                        }
                    }
                    double sum = bruteForceJointMarginal.ValueSum();
                    if (sum == 0.0)
                    {
                        sum = 1;
                    }
                    foreach (int[] assignment_1 in result.jointMarginals[f])
                    {
                        int[] bruteForceMarginalAssignment = new int[bruteForceJointMarginal.neighborIndices.Length];
                        for (int i_2 = 0; i_2 < assignment_1.Length; i_2++)
                        {
                            if (backPointers[i_2] != -1)
                            {
                                bruteForceMarginalAssignment[backPointers[i_2]] = assignment_1[i_2];
                            }
                            else
                            {
                                // Make sure all assignments that don't square with observations get 0 weight
                                System.Diagnostics.Debug.Assert((observedValue[i_2] != -1));
                                if (assignment_1[i_2] != observedValue[i_2])
                                {
                                    if (result.jointMarginals[f].GetAssignmentValue(assignment_1) != 0)
                                    {
                                        System.Console.Error.WriteLine("Joint marginals: " + Arrays.ToString(result.jointMarginals[f].neighborIndices));
                                        System.Console.Error.WriteLine("Assignment: " + Arrays.ToString(assignment_1));
                                        System.Console.Error.WriteLine("Observed Value: " + Arrays.ToString(observedValue));
                                        foreach (int[] assn in result.jointMarginals[f])
                                        {
                                            System.Console.Error.WriteLine("\t" + Arrays.ToString(assn) + ":" + result.jointMarginals[f].GetAssignmentValue(assn));
                                        }
                                    }
                                    NUnit.Framework.Assert.AreEqual(0.0, result.jointMarginals[f].GetAssignmentValue(assignment_1), 1.0e-7);
                                    goto assignments_continue;
                                }
                            }
                        }
                        NUnit.Framework.Assert.AreEqual(bruteForceJointMarginal.GetAssignmentValue(bruteForceMarginalAssignment) / sum, result.jointMarginals[f].GetAssignmentValue(assignment_1), 1.0e-3);
                        assignments_continue :;
                    }
                    marginals_continue :;
                }
            }
            else
            {
                foreach (double[] marginal in marginals)
                {
                    foreach (double d in marginal)
                    {
                        NUnit.Framework.Assert.AreEqual(1.0 / marginal.Length, d, 3.0e-2);
                    }
                }
            }
        }
        /// <exception cref="System.IO.IOException"/>
        /// <exception cref="System.TypeLoadException"/>
        public static void Main(string[] args)
        {
            //////////////////////////////////////////////////////////////
            // Generate the CoNLL CliqueTrees to use during gameplay
            //////////////////////////////////////////////////////////////
            CoNLLBenchmark coNLL = new CoNLLBenchmark();
            IList <CoNLLBenchmark.CoNLLSentence> train   = coNLL.GetSentences(DataPath + "conll.iob.4class.train");
            IList <CoNLLBenchmark.CoNLLSentence> testA   = coNLL.GetSentences(DataPath + "conll.iob.4class.testa");
            IList <CoNLLBenchmark.CoNLLSentence> testB   = coNLL.GetSentences(DataPath + "conll.iob.4class.testb");
            IList <CoNLLBenchmark.CoNLLSentence> allData = new List <CoNLLBenchmark.CoNLLSentence>();

            Sharpen.Collections.AddAll(allData, train);
            Sharpen.Collections.AddAll(allData, testA);
            Sharpen.Collections.AddAll(allData, testB);
            ICollection <string> tagsSet = new HashSet <string>();

            foreach (CoNLLBenchmark.CoNLLSentence sentence in allData)
            {
                foreach (string nerTag in sentence.ner)
                {
                    tagsSet.Add(nerTag);
                }
            }
            IList <string> tags = new List <string>();

            Sharpen.Collections.AddAll(tags, tagsSet);
            coNLL.embeddings = coNLL.GetEmbeddings(DataPath + "google-300-trimmed.ser.gz", allData);
            log.Info("Making the training set...");
            ConcatVectorNamespace @namespace = new ConcatVectorNamespace();
            int trainSize = train.Count;

            GraphicalModel[] trainingSet = new GraphicalModel[trainSize];
            for (int i = 0; i < trainSize; i++)
            {
                if (i % 10 == 0)
                {
                    log.Info(i + "/" + trainSize);
                }
                trainingSet[i] = coNLL.GenerateSentenceModel(@namespace, train[i], tags);
            }
            //////////////////////////////////////////////////////////////
            // Generate the random human observation feature vectors that we'll use
            //////////////////////////////////////////////////////////////
            Random r             = new Random(10);
            int    numFeatures   = 5;
            int    featureLength = 30;

            ConcatVector[] humanFeatureVectors = new ConcatVector[1000];
            for (int i_1 = 0; i_1 < humanFeatureVectors.Length; i_1++)
            {
                humanFeatureVectors[i_1] = new ConcatVector(numFeatures);
                for (int j = 0; j < numFeatures; j++)
                {
                    if (r.NextBoolean())
                    {
                        humanFeatureVectors[i_1].SetSparseComponent(j, r.NextInt(featureLength), r.NextDouble());
                    }
                    else
                    {
                        double[] dense = new double[featureLength];
                        for (int k = 0; k < dense.Length; k++)
                        {
                            dense[k] = r.NextDouble();
                        }
                        humanFeatureVectors[i_1].SetDenseComponent(j, dense);
                    }
                }
            }
            ConcatVector weights = new ConcatVector(numFeatures);

            for (int i_2 = 0; i_2 < numFeatures; i_2++)
            {
                double[] dense = new double[featureLength];
                for (int j = 0; j < dense.Length; j++)
                {
                    dense[j] = r.NextDouble();
                }
                weights.SetDenseComponent(i_2, dense);
            }
            //////////////////////////////////////////////////////////////
            // Actually perform gameplay-like random mutations
            //////////////////////////////////////////////////////////////
            log.Info("Warming up the JIT...");
            for (int i_3 = 0; i_3 < 10; i_3++)
            {
                log.Info(i_3);
                Gameplay(r, trainingSet[i_3], weights, humanFeatureVectors);
            }
            log.Info("Timing actual run...");
            long start = Runtime.CurrentTimeMillis();

            for (int i_4 = 0; i_4 < 10; i_4++)
            {
                log.Info(i_4);
                Gameplay(r, trainingSet[i_4], weights, humanFeatureVectors);
            }
            long duration = Runtime.CurrentTimeMillis() - start;

            log.Info("Duration: " + duration);
        }
        //////////////////////////////////////////////////////////////
        // This is an implementation of something like MCTS, trying to take advantage of the general speed gains due to fast
        // CliqueTree caching of dot products. It doesn't actually do any clever selection, preferring to select observations
        // at random.
        //////////////////////////////////////////////////////////////
        private static void Gameplay(Random r, GraphicalModel model, ConcatVector weights, ConcatVector[] humanFeatureVectors)
        {
            IList <int> variablesList     = new List <int>();
            IList <int> variableSizesList = new List <int>();

            foreach (GraphicalModel.Factor f in model.factors)
            {
                for (int i = 0; i < f.neigborIndices.Length; i++)
                {
                    int j = f.neigborIndices[i];
                    if (!variablesList.Contains(j))
                    {
                        variablesList.Add(j);
                        variableSizesList.Add(f.featuresTable.GetDimensions()[i]);
                    }
                }
            }
            int[] variables     = variablesList.Stream().MapToInt(i => i).ToArray();    // reconstructed lambdas (the converter emitted null): unbox to int[]
            int[] variableSizes = variableSizesList.Stream().MapToInt(i => i).ToArray();
            IList <GamePlayerBenchmark.SampleState> childrenOfRoot = new List <GamePlayerBenchmark.SampleState>();
            CliqueTree tree           = new CliqueTree(model, weights);
            int        initialFactors = model.factors.Count;
            // Run some "samples"
            long start         = Runtime.CurrentTimeMillis();
            long marginalsTime = 0;

            for (int i_1 = 0; i_1 < 1000; i_1++)
            {
                log.Info("\tTaking sample " + i_1);
                Stack <GamePlayerBenchmark.SampleState> stack = new Stack <GamePlayerBenchmark.SampleState>();
                GamePlayerBenchmark.SampleState         state = SelectOrCreateChildAtRandom(r, model, variables, variableSizes, childrenOfRoot, humanFeatureVectors);
                long localMarginalsTime = 0;
                // Each "sample" is 10 moves deep
                for (int j = 0; j < 10; j++)
                {
                    // log.info("\t\tFrame "+j);
                    state.Push(model);
                    System.Diagnostics.Debug.Assert((model.factors.Count == initialFactors + j + 1));
                    ///////////////////////////////////////////////////////////
                    // This is the thing we're really benchmarking
                    ///////////////////////////////////////////////////////////
                    if (state.cachedMarginal == null)
                    {
                        long s = Runtime.CurrentTimeMillis();
                        state.cachedMarginal = tree.CalculateMarginalsJustSingletons();
                        localMarginalsTime  += Runtime.CurrentTimeMillis() - s;
                    }
                    stack.Push(state);
                    state = SelectOrCreateChildAtRandom(r, model, variables, variableSizes, state.children, humanFeatureVectors);
                }
                log.Info("\t\t" + localMarginalsTime + " ms");
                marginalsTime += localMarginalsTime;
                while (!stack.Empty())
                {
                    stack.Pop().Pop(model);
                }
                System.Diagnostics.Debug.Assert((model.factors.Count == initialFactors));
            }
            log.Info("Marginals time: " + marginalsTime + " ms");
            log.Info("Avg time per marginal: " + (marginalsTime / 200) + " ms");
            log.Info("Total time: " + (Runtime.CurrentTimeMillis() - start));
        }
		public virtual void TestConstructWithObservations(TableFactorTest.PartiallyObservedConstructorData data, ConcatVector weights)
		{
			int[] obsArray = new int[9];
			for (int i = 0; i < obsArray.Length; i++)
			{
				obsArray[i] = -1;
			}
			for (int i_1 = 0; i_1 < data.observations.Length; i_1++)
			{
				obsArray[data.factor.neigborIndices[i_1]] = data.observations[i_1];
			}
			TableFactor normalObservations = new TableFactor(weights, data.factor);
			for (int i_2 = 0; i_2 < obsArray.Length; i_2++)
			{
				if (obsArray[i_2] != -1)
				{
					normalObservations = normalObservations.Observe(i_2, obsArray[i_2]);
				}
			}
			TableFactor constructedObservations = new TableFactor(weights, data.factor, data.observations);
			Assert.AssertArrayEquals(normalObservations.neighborIndices, constructedObservations.neighborIndices);
			foreach (int[] assn in normalObservations)
			{
				NUnit.Framework.Assert.AreEqual(normalObservations.GetAssignmentValue(assn), constructedObservations.GetAssignmentValue(assn), 1.0e-9);
			}
		}
        public virtual ConcatVector Optimize <T>(T[] dataset, AbstractDifferentiableFunction <T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
        {
            if (!quiet)
            {
                log.Info("\n**************\nBeginning training\n");
            }
            else
            {
                log.Info("[Beginning quiet training]");
            }
            AbstractBatchOptimizer.TrainingWorker <T> mainWorker = new AbstractBatchOptimizer.TrainingWorker <T>(this, dataset, fn, initialWeights, l2regularization, convergenceDerivativeNorm, quiet);
            new Thread(mainWorker).Start();
            BufferedReader br = new BufferedReader(new InputStreamReader(Runtime.@in));

            if (!quiet)
            {
                log.Info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
                log.Info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
                log.Info("if left to their own devices.\n");
                while (true)
                {
                    if (mainWorker.isFinished)
                    {
                        log.Info("training completed without interruption");
                        return(mainWorker.weights);
                    }
                    try
                    {
                        if (br.Ready())
                        {
                            log.Info("received quit command: quitting");
                            log.Info("training completed by interruption");
                            mainWorker.isFinished = true;
                            return(mainWorker.weights);
                        }
                    }
                    catch (IOException e)
                    {
                        Sharpen.Runtime.PrintStackTrace(e);
                    }
                    // Sleep briefly between polls so this loop doesn't pin a core busy-waiting
                    try
                    {
                        Thread.Sleep(100);
                    }
                    catch (Exception e2)
                    {
                        throw new RuntimeInterruptedException(e2);
                    }
                }
            }
            else
            {
                lock (mainWorker.naturalTerminationBarrier)
                {
                    // Hold the lock while checking the flag, so a NotifyAll that fires between
                    // the check and the Wait cannot be missed
                    while (!mainWorker.isFinished)
                    {
                        try
                        {
                            Sharpen.Runtime.Wait(mainWorker.naturalTerminationBarrier);
                        }
                        catch (Exception e)
                        {
                            throw new RuntimeInterruptedException(e);
                        }
                    }
                }
                log.Info("[Quiet training complete]");
                return(mainWorker.weights);
            }
        }
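        // For context, a sketch of a typical call site. BacktrackingAdaGradOptimizer is the
        // concrete AbstractBatchOptimizer shipped alongside this class in the loglinear
        // package; LoadTrainingModels() is an assumed helper, and the hyperparameters are
        // purely illustrative, not recommendations.
        private static ConcatVector TrainSketch()
        {
            GraphicalModel[] trainSet = LoadTrainingModels();   // hypothetical helper
            AbstractBatchOptimizer optimizer = new BacktrackingAdaGradOptimizer();
            return optimizer.Optimize(
                trainSet,
                new LogLikelihoodDifferentiableFunction(),
                new ConcatVector(0),    // start from empty weights
                0.01,                   // l2regularization
                1.0e-6,                 // convergenceDerivativeNorm
                false);                 // quiet = false: log progress, allow stdin interrupt
        }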
            public override GraphicalModel Generate(SourceOfRandomness sourceOfRandomness, IGenerationStatus generationStatus)
            {
                GraphicalModel model = new GraphicalModel();

                // Create the variables and factors. These are deliberately tiny so that the brute force approach is tractable
                int[] variableSizes = new int[8];
                for (int i = 0; i < variableSizes.Length; i++)
                {
                    variableSizes[i] = sourceOfRandomness.NextInt(1, 3);
                }
                // Traverse in a randomized BFS to ensure the generated graph is a tree
                if (sourceOfRandomness.NextBoolean())
                {
                    GenerateCliques(variableSizes, new List <int>(), new HashSet <int>(), model, sourceOfRandomness);
                }
                else
                {
                    // Or generate a linear chain CRF, because our random BFS doesn't generate these very often, and they're very
                    // common in practice, so worth testing densely
                    for (int i_1 = 0; i_1 < variableSizes.Length; i_1++)
                    {
                        // Add unary factor. NOTE: the assignment-featurizer lambda was dropped by the
                        // Java-to-C# converter (it came through as "null"); it is restored here as a
                        // call to a presumed RandomVector(sourceOfRandomness) helper that builds a
                        // random ConcatVector, mirroring the inline generation in GenerateCliques below
                        GraphicalModel.Factor unary = model.AddFactor(new int[] { i_1 }, new int[] { variableSizes[i_1] }, asgn => RandomVector(sourceOfRandomness));
                        // "Cook" the randomly generated feature vector thunks, so they don't change as we run the system
                        foreach (int[] assignment in unary.featuresTable)
                        {
                            ConcatVector randomlyGenerated = unary.featuresTable.GetAssignmentValue(assignment).Get();
                            unary.featuresTable.SetAssignmentValue(assignment, () => randomlyGenerated);
                        }
                        // Add binary factor, with the featurizer restored the same way
                        if (i_1 < variableSizes.Length - 1)
                        {
                            GraphicalModel.Factor binary = model.AddFactor(new int[] { i_1, i_1 + 1 }, new int[] { variableSizes[i_1], variableSizes[i_1 + 1] }, asgn => RandomVector(sourceOfRandomness));
                            // "Cook" the randomly generated feature vector thunks, so they don't change as we run the system
                            foreach (int[] assignment_1 in binary.featuresTable)
                            {
                                ConcatVector randomlyGenerated = binary.featuresTable.GetAssignmentValue(assignment_1).Get();
                                binary.featuresTable.SetAssignmentValue(assignment_1, () => randomlyGenerated);
                            }
                        }
                    }
                }
                // Add metadata to the variables, factors, and model
                GenerateMetaData(sourceOfRandomness, model.GetModelMetaDataByReference());
                for (int i_2 = 0; i_2 < 20; i_2++)
                {
                    GenerateMetaData(sourceOfRandomness, model.GetVariableMetaDataByReference(i_2));
                }
                foreach (GraphicalModel.Factor factor in model.factors)
                {
                    GenerateMetaData(sourceOfRandomness, factor.GetMetaDataByReference());
                }
                // Observe a few of the variables
                foreach (GraphicalModel.Factor f in model.factors)
                {
                    for (int i_1 = 0; i_1 < f.neigborIndices.Length; i_1++)
                    {
                        if (sourceOfRandomness.NextDouble() > 0.8)
                        {
                            int obs = sourceOfRandomness.NextInt(f.featuresTable.GetDimensions()[i_1]);
                            model.GetVariableMetaDataByReference(f.neigborIndices[i_1])[CliqueTree.VariableObservedValue] = string.Empty + obs;
                        }
                    }
                }
                return(model);
            }
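            // How this generator is typically driven: a junit-quickcheck-style harness
            // instantiates the enclosing generator class and calls Generate once per trial.
            // A hand-driven sketch (the SourceOfRandomness constructor is an assumed API,
            // and passing null for IGenerationStatus works only because Generate ignores it):
            //
            //     SourceOfRandomness rand = new SourceOfRandomness(new Random(42));
            //     GraphicalModel sample = generator.Generate(rand, null);
            //
            // "sample" then has 8 small variables, tree-structured factors, metadata, and a
            // random subset of variables marked observed.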
        /// <summary>This is called at the beginning of each batch optimization.</summary>
        /// <remarks>
        /// This is called at the beginning of each batch optimization. It should return a fresh OptimizationState object that
        /// will then be handed to updateWeights() on each update.
        /// </remarks>
        /// <param name="initialWeights">the initial weights for the optimizer to use</param>
        /// <returns>a fresh OptimizationState</returns>
        protected internal abstract AbstractBatchOptimizer.OptimizationState GetFreshOptimizationState(ConcatVector initialWeights);
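        // To make the contract above concrete, a minimal *hypothetical* subclass. The state
        // class, the fixed step size, and the UpdateWeights signature (inferred from the
        // call in TrainingWorker.Run() further down) are illustrative, not the package's
        // actual optimizer.
        public class FixedStepOptimizer : AbstractBatchOptimizer
        {
            private class FixedStepState : AbstractBatchOptimizer.OptimizationState
            {
                internal const double StepSize = 0.1;   // illustrative constant
            }

            protected internal override AbstractBatchOptimizer.OptimizationState GetFreshOptimizationState(ConcatVector initialWeights)
            {
                // A fresh state per Optimize() call, so separate runs share no history
                return new FixedStepState();
            }

            protected internal override bool UpdateWeights(ConcatVector weights, ConcatVector derivative, double logLikelihood, AbstractBatchOptimizer.OptimizationState optimizationState, bool quiet)
            {
                // Plain gradient ascent on the regularized log-likelihood
                weights.AddVectorInPlace(derivative, FixedStepState.StepSize);
                return false;   // never self-report convergence; the derivative-norm check in Run() handles stopping
            }
        }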
            private void GenerateCliques(int[] variableSizes, IList <int> startSet, ICollection <int> alreadyRepresented, GraphicalModel model, SourceOfRandomness randomness)
            {
                if (alreadyRepresented.Count == variableSizes.Length)
                {
                    return;
                }
                // Generate the clique variable set
                IList <int> cliqueContents = new List <int>();

                Sharpen.Collections.AddAll(cliqueContents, startSet);
                Sharpen.Collections.AddAll(alreadyRepresented, startSet);
                while (true)
                {
                    if (alreadyRepresented.Count == variableSizes.Length)
                    {
                        break;
                    }
                    if (cliqueContents.Count == 0 || randomness.NextDouble(0, 1) < 0.7)
                    {
                        int gen;
                        do
                        {
                            gen = randomness.NextInt(variableSizes.Length);
                        }while (alreadyRepresented.Contains(gen));
                        alreadyRepresented.Add(gen);
                        cliqueContents.Add(gen);
                    }
                    else
                    {
                        break;
                    }
                }
                // Create the actual table
                int[] neighbors     = new int[cliqueContents.Count];
                int[] neighborSizes = new int[neighbors.Length];
                for (int j = 0; j < neighbors.Length; j++)
                {
                    neighbors[j]     = cliqueContents[j];
                    neighborSizes[j] = variableSizes[neighbors[j]];
                }
                ConcatVectorTable table = new ConcatVectorTable(neighborSizes);

                foreach (int[] assignment in table)
                {
                    // Generate a vector
                    ConcatVector v = new ConcatVector(ConcatVecComponents);
                    for (int x = 0; x < ConcatVecComponents; x++)
                    {
                        if (randomness.NextBoolean())
                        {
                            v.SetSparseComponent(x, randomness.NextInt(32), randomness.NextDouble());
                        }
                        else
                        {
                            double[] val = new double[randomness.NextInt(12)];
                            for (int y = 0; y < val.Length; y++)
                            {
                                val[y] = randomness.NextDouble();
                            }
                            v.SetDenseComponent(x, val);
                        }
                    }
                    // Store the vector in the table as a thunk. The original lambda was lost
                    // in the Java-to-C# conversion (it came through as "null"); () => v is the
                    // presumed intent, assuming the table stores ConcatVector suppliers (values
                    // are read back via .Get() elsewhere in this file)
                    table.SetAssignmentValue(assignment, () => v);
                }
                model.AddFactor(table, neighbors);
                // Pick the number of children
                IList <int> availableVariables = new List <int>();

                Sharpen.Collections.AddAll(availableVariables, cliqueContents);
                availableVariables.RemoveAll(startSet);
                int numChildren = randomness.NextInt(0, availableVariables.Count);

                if (numChildren == 0)
                {
                    return;
                }
                IList <IList <int> > children = new List <IList <int> >();

                for (int i = 0; i < numChildren; i++)
                {
                    children.Add(new List <int>());
                }
                // divide up the shared variables across the children
                int cursor = 0;

                while (true)
                {
                    if (availableVariables.Count == 0)
                    {
                        break;
                    }
                    if (children[cursor].Count == 0 || randomness.NextBoolean())
                    {
                        int gen = randomness.NextInt(availableVariables.Count);
                        children[cursor].Add(availableVariables[gen]);
                        availableVariables.Remove(availableVariables[gen]);
                    }
                    else
                    {
                        break;
                    }
                    cursor = (cursor + 1) % numChildren;
                }
                foreach (IList <int> shared1 in children)
                {
                    foreach (int i_1 in shared1)
                    {
                        foreach (IList <int> shared2 in children)
                        {
                            System.Diagnostics.Debug.Assert((shared1 == shared2 || !shared2.Contains(i_1)));
                        }
                    }
                }
                foreach (IList <int> shared in children)
                {
                    if (shared.Count > 0)
                    {
                        GenerateCliques(variableSizes, shared, alreadyRepresented, model, randomness);
                    }
                }
            }
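            // Shape of the recursion above, traced on an assumed tiny run over 8 variables:
            //   root call:  startSet = {}   -> clique {4, 1, 7}, children share {1} and {7}
            //   child call: startSet = {1}  -> clique {1, 0, 5}, child shares {0, 5}
            //   child call: startSet = {7}  -> clique {7, 2},    no children
            //   ...and so on until alreadyRepresented covers every variable.
            // Each child clique overlaps its parent exactly in its startSet, and sibling
            // shared sets are disjoint (the assert above), so the cliques form a tree.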
            public virtual void Run()
            {
                // Multithreading stuff
                int numThreads = Math.Max(1, Runtime.GetRuntime().AvailableProcessors());

                // Allocate the generic array directly: the converter's
                // "(IList<T>[])(new IList[numThreads])" cast would throw at runtime in C#
                IList <T>[] queues = new IList <T> [numThreads];
                Random      r      = new Random();

                // Allocate work to make estimated cost of work per thread as even as possible
                if (this.useThreads)
                {
                    for (int i = 0; i < numThreads; i++)
                    {
                        queues[i] = new List <T>();
                    }
                    int[] queueEstimatedTotalCost = new int[numThreads];
                    foreach (T datum in this.dataset)
                    {
                        int datumEstimatedCost = this.EstimateRelativeRuntime(datum);
                        int minCostQueue       = 0;
                        for (int i_1 = 0; i_1 < numThreads; i_1++)
                        {
                            if (queueEstimatedTotalCost[i_1] < queueEstimatedTotalCost[minCostQueue])
                            {
                                minCostQueue = i_1;
                            }
                        }
                        queueEstimatedTotalCost[minCostQueue] += datumEstimatedCost;
                        queues[minCostQueue].Add(datum);
                    }
                }
                while (!this.isFinished)
                {
                    // Collect log-likelihood and derivatives
                    long         startTime     = Runtime.CurrentTimeMillis();
                    long         threadWaiting = 0;
                    ConcatVector derivative    = this.weights.NewEmptyClone();
                    double       logLikelihood = 0.0;
                    if (this.useThreads)
                    {
                        AbstractBatchOptimizer.GradientWorker[] workers = new AbstractBatchOptimizer.GradientWorker[numThreads];
                        Thread[] threads = new Thread[numThreads];
                        for (int i = 0; i < workers.Length; i++)
                        {
                            workers[i]             = new AbstractBatchOptimizer.GradientWorker(this, i, numThreads, queues[i], this.fn, this.weights);
                            threads[i]             = new Thread(workers[i]);
                            workers[i].jvmThreadId = threads[i].GetId();
                            threads[i].Start();
                        }
                        // This is for logging
                        long minFinishTime = long.MaxValue;
                        long maxFinishTime = long.MinValue;
                        // This is for re-balancing
                        long minCPUTime    = long.MaxValue;
                        long maxCPUTime    = long.MinValue;
                        int  slowestWorker = 0;
                        int  fastestWorker = 0;
                        for (int i_1 = 0; i_1 < workers.Length; i_1++)
                        {
                            try
                            {
                                threads[i_1].Join();
                            }
                            catch (Exception e)
                            {
                                throw new RuntimeInterruptedException(e);
                            }
                            logLikelihood += workers[i_1].localLogLikelihood;
                            derivative.AddVectorInPlace(workers[i_1].localDerivative, 1.0);
                            if (workers[i_1].finishedAtTime < minFinishTime)
                            {
                                minFinishTime = workers[i_1].finishedAtTime;
                            }
                            if (workers[i_1].finishedAtTime > maxFinishTime)
                            {
                                maxFinishTime = workers[i_1].finishedAtTime;
                            }
                            if (workers[i_1].cpuTimeRequired < minCPUTime)
                            {
                                fastestWorker = i_1;
                                minCPUTime    = workers[i_1].cpuTimeRequired;
                            }
                            if (workers[i_1].cpuTimeRequired > maxCPUTime)
                            {
                                slowestWorker = i_1;
                                maxCPUTime    = workers[i_1].cpuTimeRequired;
                            }
                        }
                        threadWaiting = maxFinishTime - minFinishTime;
                        // Try to reallocate work dynamically to minimize waiting on subsequent rounds
                        // Figure out the percentage of work represented by the waiting
                        double waitingPercentage = (double)(maxCPUTime - minCPUTime) / (double)maxCPUTime;
                        int    needTransferItems = (int)Math.Floor(queues[slowestWorker].Count * waitingPercentage * 0.5);
                        for (int i_2 = 0; i_2 < needTransferItems; i_2++)
                        {
                            int toTransfer = r.NextInt(queues[slowestWorker].Count);
                            T   datum      = queues[slowestWorker][toTransfer];
                            // RemoveAt, not Remove: toTransfer is an index
                            // (Java's List.remove(int) maps to C#'s RemoveAt)
                            queues[slowestWorker].RemoveAt(toTransfer);
                            queues[fastestWorker].Add(datum);
                        }
                        // Check for user interrupt
                        if (this.isFinished)
                        {
                            return;
                        }
                    }
                    else
                    {
                        foreach (T datum in this.dataset)
                        {
                            System.Diagnostics.Debug.Assert((datum != null));
                            logLikelihood += this.fn.GetSummaryForInstance(datum, this.weights, derivative);
                            // Check for user interrupt
                            if (this.isFinished)
                            {
                                return;
                            }
                        }
                    }
                    logLikelihood /= this.dataset.Length;
                    // Average the derivative as well: the mapping lambda came through the
                    // converter as "null"; (d) => d / dataset.Length mirrors the log-likelihood
                    // normalization on the line above
                    derivative.MapInPlace(d => d / this.dataset.Length);
                    long gradientComputationTime = Runtime.CurrentTimeMillis() - startTime;
                    // Regularization
                    logLikelihood = logLikelihood - (this.l2regularization * this.weights.DotProduct(this.weights));
                    derivative.AddVectorInPlace(this.weights, -2 * this.l2regularization);
                    // Zero out the derivative on the components we're holding fixed
                    foreach (AbstractBatchOptimizer.Constraint constraint in this._enclosing.constraints)
                    {
                        constraint.ApplyToDerivative(derivative);
                    }
                    // If our derivative is sufficiently small, we've converged
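                    // (DotProduct of the derivative with itself is the *squared* L2 norm, so
                    // convergenceDerivativeNorm is effectively a threshold on ||gradient||^2)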
                    double derivativeNorm = derivative.DotProduct(derivative);
                    if (derivativeNorm < this.convergenceDerivativeNorm)
                    {
                        if (!this.quiet)
                        {
                            AbstractBatchOptimizer.log.Info("Derivative norm " + derivativeNorm + " < " + this.convergenceDerivativeNorm + ": quitting");
                        }
                        break;
                    }
                    // Do the actual computation
                    if (!this.quiet)
                    {
                        AbstractBatchOptimizer.log.Info("[" + gradientComputationTime + " ms, threads waiting " + threadWaiting + " ms]");
                    }
                    bool converged = this._enclosing.UpdateWeights(this.weights, derivative, logLikelihood, this.optimizationState, this.quiet);
                    // Apply constraints to the weights vector
                    foreach (AbstractBatchOptimizer.Constraint constraint_1 in this._enclosing.constraints)
                    {
                        constraint_1.ApplyToWeights(this.weights);
                    }
                    if (converged)
                    {
                        break;
                    }
                }
                lock (this.naturalTerminationBarrier)
                {
                    // Set the flag before notifying, and under the same lock the quiet-mode
                    // waiter uses, so a waiter re-checking isFinished can never miss the wake-up
                    this.isFinished = true;
                    Sharpen.Runtime.NotifyAll(this.naturalTerminationBarrier);
                }
            }
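            // A worked example of the dynamic rebalancing heuristic in Run() above, under
            // assumed timings: if the slowest worker needed 1000 ms of CPU, the fastest
            // 600 ms, and the slowest queue holds 40 items, then
            //   waitingPercentage = (1000 - 600) / 1000   = 0.4
            //   needTransferItems = floor(40 * 0.4 * 0.5) = 8
            // so 8 random items move from the slowest to the fastest queue, closing roughly
            // half of the measured gap each round and damping oscillation.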