        /// <summary>
        /// Returns the index of the highest-valued element in the
        /// <paramref name="predictions"/> matrix.
        /// Indexed from 0.
        /// </summary>
        private static int GetPredictedClass(SimpleMatrix predictions)
        {
            int argmax = 0;

            for (int i = 1; i < predictions.GetNumElements(); ++i)
            {
                if (predictions.Get(i) > predictions.Get(argmax))
                {
                    argmax = i;
                }
            }
            return(argmax);
        }
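A minimal usage sketch of this helper, assuming it is called from inside the same class; the 3x1 prediction vector and its values are made up for illustration:

        // Hypothetical 3x1 column of class scores.
        SimpleMatrix predictions = new SimpleMatrix(3, 1);
        predictions.Set(0, 0, 0.1);
        predictions.Set(1, 0, 0.7);
        predictions.Set(2, 0, 0.2);
        // The loop above scans the elements linearly and keeps the argmax,
        // so this call returns 1 (the second class, indexed from 0).
        int predicted = GetPredictedClass(predictions);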
Code Example #2
        public static double[] ParamsToVector(int totalSize, params IEnumerator <SimpleMatrix>[] matrices)
        {
            double[] theta = new double[totalSize];
            int      index = 0;

            foreach (IEnumerator <SimpleMatrix> matrixIterator in matrices)
            {
                while (matrixIterator.MoveNext())
                {
                    SimpleMatrix matrix      = matrixIterator.Current;
                    int          numElements = matrix.GetNumElements();
                    // Console.WriteLine(numElements); // uncomment to see how large each matrix is
                    for (int i = 0; i < numElements; ++i)
                    {
                        theta[index] = matrix.Get(i);
                        ++index;
                    }
                }
            }
            if (index != totalSize)
            {
                throw new AssertionError("Did not entirely fill the theta vector: expected " + totalSize + " used " + index);
            }
            return(theta);
        }
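A sketch of how a caller might pack two matrices into one parameter vector with this method; the sizes and the list-to-enumerator wrapping are illustrative assumptions rather than code from the original call sites:

        // Hypothetical matrices; totalSize must equal the sum of their element counts.
        SimpleMatrix a = new SimpleMatrix(2, 2);   // 4 elements
        SimpleMatrix b = new SimpleMatrix(3, 1);   // 3 elements
        IList<SimpleMatrix> parameters = new List<SimpleMatrix> { a, b };
        double[] theta = ParamsToVector(a.GetNumElements() + b.GetNumElements(),
                                        parameters.GetEnumerator());
        // theta now holds a's 4 entries followed by b's 3 entries, in element order.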
Code Example #3
 /// <exception cref="System.IO.IOException"/>
 public static void OutputMatrix(BufferedWriter bout, SimpleMatrix matrix)
 {
     for (int i = 0; i < matrix.GetNumElements(); ++i)
     {
         bout.Write("  " + matrix.Get(i));
     }
     bout.NewLine();
 }
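A sketch of calling this writer, reusing the FileWriter/BufferedWriter pattern that appears later in Code Example #12; the file name is hypothetical:

 // Hypothetical output file; OutputMatrix writes every element on one line, then a newline.
 FileWriter fout = new FileWriter("matrix.txt");
 BufferedWriter bout = new BufferedWriter(fout);
 OutputMatrix(bout, new SimpleMatrix(2, 2));
 bout.Close();
 fout.Close();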
Code Example #4
        public static IList <double> GetPredictionsAsStringList(Tree tree)
        {
            SimpleMatrix   predictions       = GetPredictions(tree);
            IList <double> listOfPredictions = new List <double>();

            for (int i = 0; i < predictions.NumRows(); i++)
            {
                listOfPredictions.Add(predictions.Get(i));
            }
            return(listOfPredictions);
        }
Code Example #5
        private static SimpleTensor GetTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector)
        {
            int          size  = deltaFull.GetNumElements();
            SimpleTensor Wt_df = new SimpleTensor(size * 2, size * 2, size);
            // TODO: combine this concatenation with computeTensorDeltaDown?
            SimpleMatrix fullVector = NeuralUtils.Concatenate(leftVector, rightVector);

            for (int slice = 0; slice < size; ++slice)
            {
                Wt_df.SetSlice(slice, fullVector.Scale(deltaFull.Get(slice)).Mult(fullVector.Transpose()));
            }
            return(Wt_df);
        }
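This helper has no doc comment; what it computes is the tensor-layer gradient, where slice s of Wt_df is the outer product of the concatenated child vector with itself, scaled by the s-th entry of deltaFull. A toy sketch of that per-slice term, using only SimpleMatrix operations already shown above and placeholder 2x1 vectors:

        // Hypothetical 2x1 child vectors and a 2x1 delta, so size == 2 here.
        SimpleMatrix left  = new SimpleMatrix(2, 1);
        SimpleMatrix right = new SimpleMatrix(2, 1);
        SimpleMatrix delta = new SimpleMatrix(2, 1);
        SimpleMatrix full  = NeuralUtils.Concatenate(left, right);              // 4x1
        // Slice 0 of the tensor gradient: delta[0] * full * full^T, a 4x4 matrix.
        SimpleMatrix slice0 = full.Scale(delta.Get(0)).Mult(full.Transpose());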
Code Example #6
        /// <summary>Returns true iff every element of the matrix is 0.</summary>
        public static bool IsZero(SimpleMatrix matrix)
        {
            int size = matrix.GetNumElements();

            for (int i = 0; i < size; ++i)
            {
                if (matrix.Get(i) != 0.0)
                {
                    return(false);
                }
            }
            return(true);
        }
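A minimal behavioral sketch, assuming a freshly constructed SimpleMatrix starts out zero-filled (as the EJML-style API is expected to do):

        SimpleMatrix m = new SimpleMatrix(2, 2);   // assumed to start as all zeros
        bool allZero = IsZero(m);                  // true
        m.Set(0, 0, 1.0);
        bool stillZero = IsZero(m);                // false: one element is now non-zero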
Code Example #7
        /// <summary>Applies tanh to each of the entries in the matrix.</summary>
        /// <remarks>Applies tanh to each of the entries in the matrix.  Returns a new matrix.</remarks>
        public static SimpleMatrix ElementwiseApplyTanh(SimpleMatrix input)
        {
            SimpleMatrix output = new SimpleMatrix(input);

            for (int i = 0; i < output.NumRows(); ++i)
            {
                for (int j = 0; j < output.NumCols(); ++j)
                {
                    output.Set(i, j, Math.Tanh(output.Get(i, j)));
                }
            }
            return(output);
        }
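A short sketch of the copy-then-squash behavior; the input values are illustrative:

        SimpleMatrix input = new SimpleMatrix(1, 2);
        input.Set(0, 0, 0.0);
        input.Set(0, 1, 100.0);
        SimpleMatrix squashed = ElementwiseApplyTanh(input);
        // squashed.Get(0, 0) is 0.0 and squashed.Get(0, 1) is very close to 1.0;
        // input itself is untouched because the method writes into a copy.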
Code Example #8
        private static SimpleMatrix ComputeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector, SimpleMatrix W, SimpleTensor Wt)
        {
            SimpleMatrix WTDelta       = W.Transpose().Mult(deltaFull);
            SimpleMatrix WTDeltaNoBias = WTDelta.ExtractMatrix(0, deltaFull.NumRows() * 2, 0, 1);
            int          size          = deltaFull.GetNumElements();
            SimpleMatrix deltaTensor   = new SimpleMatrix(size * 2, 1);
            SimpleMatrix fullVector    = NeuralUtils.Concatenate(leftVector, rightVector);

            for (int slice = 0; slice < size; ++slice)
            {
                SimpleMatrix scaledFullVector = fullVector.Scale(deltaFull.Get(slice));
                deltaTensor = deltaTensor.Plus(Wt.GetSlice(slice).Plus(Wt.GetSlice(slice).Transpose()).Mult(scaledFullVector));
            }
            return(deltaTensor.Plus(WTDeltaNoBias));
        }
Code Example #9
        /// <summary>Applies softmax to all of the elements of the matrix.</summary>
        /// <remarks>
        /// Applies softmax to all of the elements of the matrix.  The return
        /// matrix will have all of its elements sum to 1.  If your matrix is
        /// not already a vector, be sure this is what you actually want.
        /// </remarks>
        public static SimpleMatrix Softmax(SimpleMatrix input)
        {
            SimpleMatrix output = new SimpleMatrix(input);

            for (int i = 0; i < output.NumRows(); ++i)
            {
                for (int j = 0; j < output.NumCols(); ++j)
                {
                    output.Set(i, j, Math.Exp(output.Get(i, j)));
                }
            }
            double sum = output.ElementSum();

            // will be safe, since exp should never return 0
            return(output.Scale(1.0 / sum));
        }
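A small sketch of the normalization property described in the remarks; the raw scores are illustrative:

        SimpleMatrix scores = new SimpleMatrix(3, 1);
        scores.Set(0, 0, 1.0);
        scores.Set(1, 0, 2.0);
        scores.Set(2, 0, 3.0);
        SimpleMatrix probs = Softmax(scores);
        // probs.ElementSum() is 1.0 up to floating-point error, and the largest
        // raw score receives the largest probability.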
Code Example #10
        /// <summary>Returns the probability of the predicted class as a double.</summary>
        /// <remarks>
        /// Returns the probability of the predicted class as a double. If it is not defined
        /// for a node, it returns -1.0.
        /// </remarks>
        /// <returns>Either the label probability or -1.0 if none</returns>
        public static double GetPredictedClassProb(ILabel label)
        {
            if (!(label is CoreLabel))
            {
                throw new ArgumentException("CoreLabels required to get the attached predicted class probability");
            }
            int?         val         = ((CoreLabel)label).Get(typeof(RNNCoreAnnotations.PredictedClass));
            SimpleMatrix predictions = ((CoreLabel)label).Get(typeof(RNNCoreAnnotations.Predictions));

            if (val != null)
            {
                return(predictions.Get(val.Value));
            }
            else
            {
                return(-1.0);
            }
        }
Code Example #11
        /// <summary>Outputs the scores from the tree.</summary>
        /// <remarks>
        /// Outputs the scores from the tree.  Counts the tree nodes the
        /// same as setIndexLabels.
        /// </remarks>
        private static int OutputTreeScores(TextWriter @out, Tree tree, int index)
        {
            if (tree.IsLeaf())
            {
                return(index);
            }
            @out.Write("  " + index + ':');
            SimpleMatrix vector = RNNCoreAnnotations.GetPredictions(tree);

            for (int i = 0; i < vector.GetNumElements(); ++i)
            {
                @out.Write("  " + Nf.Format(vector.Get(i)));
            }
            @out.WriteLine();
            index++;
            foreach (Tree child in tree.Children())
            {
                index = OutputTreeScores(@out, child, index);
            }
            return(index);
        }
Code Example #12
        /// <exception cref="System.IO.IOException"/>
        public static void Main(string[] args)
        {
            string modelPath = null;
            string outputDir = null;

            for (int argIndex = 0; argIndex < args.Length;)
            {
                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model"))
                {
                    modelPath = args[argIndex + 1];
                    argIndex += 2;
                }
                else
                {
                    if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output"))
                    {
                        outputDir = args[argIndex + 1];
                        argIndex += 2;
                    }
                    else
                    {
                        log.Info("Unknown argument " + args[argIndex]);
                        Help();
                    }
                }
            }
            if (outputDir == null || modelPath == null)
            {
                Help();
            }
            File outputFile = new File(outputDir);

            FileSystem.CheckNotExistsOrFail(outputFile);
            FileSystem.MkdirOrFail(outputFile);
            LexicalizedParser parser     = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath));
            DVModel           model      = DVParser.GetModelFromLexicalizedParser(parser);
            string            binaryWDir = outputDir + File.separator + "binaryW";

            FileSystem.MkdirOrFail(binaryWDir);
            foreach (TwoDimensionalMap.Entry <string, string, SimpleMatrix> entry in model.binaryTransform)
            {
                string filename = binaryWDir + File.separator + entry.GetFirstKey() + "_" + entry.GetSecondKey() + ".txt";
                DumpMatrix(filename, entry.GetValue());
            }
            string binaryScoreDir = outputDir + File.separator + "binaryScore";

            FileSystem.MkdirOrFail(binaryScoreDir);
            foreach (TwoDimensionalMap.Entry <string, string, SimpleMatrix> entry_1 in model.binaryScore)
            {
                string filename = binaryScoreDir + File.separator + entry_1.GetFirstKey() + "_" + entry_1.GetSecondKey() + ".txt";
                DumpMatrix(filename, entry_1.GetValue());
            }
            string unaryWDir = outputDir + File.separator + "unaryW";

            FileSystem.MkdirOrFail(unaryWDir);
            foreach (KeyValuePair <string, SimpleMatrix> entry_2 in model.unaryTransform)
            {
                string filename = unaryWDir + File.separator + entry_2.Key + ".txt";
                DumpMatrix(filename, entry_2.Value);
            }
            string unaryScoreDir = outputDir + File.separator + "unaryScore";

            FileSystem.MkdirOrFail(unaryScoreDir);
            foreach (KeyValuePair <string, SimpleMatrix> entry_3 in model.unaryScore)
            {
                string filename = unaryScoreDir + File.separator + entry_3.Key + ".txt";
                DumpMatrix(filename, entry_3.Value);
            }
            string         embeddingFile = outputDir + File.separator + "embeddings.txt";
            FileWriter     fout          = new FileWriter(embeddingFile);
            BufferedWriter bout          = new BufferedWriter(fout);

            foreach (KeyValuePair <string, SimpleMatrix> entry_4 in model.wordVectors)
            {
                bout.Write(entry_4.Key);
                SimpleMatrix vector = entry_4.Value;
                for (int i = 0; i < vector.NumRows(); ++i)
                {
                    bout.Write("  " + vector.Get(i, 0));
                }
                bout.Write("\n");
            }
            bout.Close();
            fout.Close();
        }
Code Example #13
        /// <exception cref="System.Exception"/>
        public static void Main(string[] args)
        {
            string         modelPath          = null;
            string         outputPath         = null;
            string         testTreebankPath   = null;
            IFileFilter    testTreebankFilter = null;
            IList <string> unusedArgs         = new List <string>();

            for (int argIndex = 0; argIndex < args.Length;)
            {
                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model"))
                {
                    modelPath = args[argIndex + 1];
                    argIndex += 2;
                }
                else
                {
                    if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-testTreebank"))
                    {
                        Pair <string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
                        argIndex           = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                        testTreebankPath   = treebankDescription.First();
                        testTreebankFilter = treebankDescription.Second();
                    }
                    else
                    {
                        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output"))
                        {
                            outputPath = args[argIndex + 1];
                            argIndex  += 2;
                        }
                        else
                        {
                            unusedArgs.Add(args[argIndex++]);
                        }
                    }
                }
            }
            if (modelPath == null)
            {
                throw new ArgumentException("Need to specify -model");
            }
            if (testTreebankPath == null)
            {
                throw new ArgumentException("Need to specify -testTreebank");
            }
            if (outputPath == null)
            {
                throw new ArgumentException("Need to specify -output");
            }
            string[]          newArgs      = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
            LexicalizedParser lexparser    = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs));
            Treebank          testTreebank = null;

            if (testTreebankPath != null)
            {
                log.Info("Reading in trees from " + testTreebankPath);
                if (testTreebankFilter != null)
                {
                    log.Info("Filtering on " + testTreebankFilter);
                }
                testTreebank = lexparser.GetOp().tlpParams.MemoryTreebank();
                testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
                log.Info("Read in " + testTreebank.Count + " trees for testing");
            }
            FileWriter     @out = new FileWriter(outputPath);
            BufferedWriter bout = new BufferedWriter(@out);

            log.Info("Parsing " + testTreebank.Count + " trees");
            int count = 0;
            IList <FindNearestNeighbors.ParseRecord> records = Generics.NewArrayList();

            foreach (Tree goldTree in testTreebank)
            {
                IList <Word> tokens      = goldTree.YieldWords();
                IParserQuery parserQuery = lexparser.ParserQuery();
                if (!parserQuery.Parse(tokens))
                {
                    throw new AssertionError("Could not parse: " + tokens);
                }
                if (!(parserQuery is RerankingParserQuery))
                {
                    throw new ArgumentException("Expected a LexicalizedParser with a Reranker attached");
                }
                RerankingParserQuery rpq = (RerankingParserQuery)parserQuery;
                if (!(rpq.RerankerQuery() is DVModelReranker.Query))
                {
                    throw new ArgumentException("Expected a LexicalizedParser with a DVModel attached");
                }
                DeepTree     tree       = ((DVModelReranker.Query)rpq.RerankerQuery()).GetDeepTrees()[0];
                SimpleMatrix rootVector = null;
                foreach (KeyValuePair <Tree, SimpleMatrix> entry in tree.GetVectors())
                {
                    if (entry.Key.Label().Value().Equals("ROOT"))
                    {
                        rootVector = entry.Value;
                        break;
                    }
                }
                if (rootVector == null)
                {
                    throw new AssertionError("Could not find root nodevector");
                }
                @out.Write(tokens + "\n");
                @out.Write(tree.GetTree() + "\n");
                for (int i = 0; i < rootVector.GetNumElements(); ++i)
                {
                    @out.Write("  " + rootVector.Get(i));
                }
                @out.Write("\n\n\n");
                count++;
                if (count % 10 == 0)
                {
                    log.Info("  " + count);
                }
                records.Add(new FindNearestNeighbors.ParseRecord(tokens, goldTree, tree.GetTree(), rootVector, tree.GetVectors()));
            }
            log.Info("  done parsing");
            IList <Pair <Tree, SimpleMatrix> > subtrees = Generics.NewArrayList();

            foreach (FindNearestNeighbors.ParseRecord record in records)
            {
                foreach (KeyValuePair <Tree, SimpleMatrix> entry in record.nodeVectors)
                {
                    if (entry.Key.GetLeaves().Count <= maxLength)
                    {
                        subtrees.Add(Pair.MakePair(entry.Key, entry.Value));
                    }
                }
            }
            log.Info("There are " + subtrees.Count + " subtrees in the set of trees");
            PriorityQueue <ScoredObject <Pair <Tree, Tree> > > bestmatches = new PriorityQueue <ScoredObject <Pair <Tree, Tree> > >(101, ScoredComparator.DescendingComparator);

            for (int i_1 = 0; i_1 < subtrees.Count; ++i_1)
            {
                log.Info(subtrees[i_1].First().YieldWords());
                log.Info(subtrees[i_1].First());
                for (int j = 0; j < subtrees.Count; ++j)
                {
                    if (i_1 == j)
                    {
                        continue;
                    }
                    // TODO: look at basic category?
                    double normF = subtrees[i_1].Second().Minus(subtrees[j].Second()).NormF();
                    bestmatches.Add(new ScoredObject <Pair <Tree, Tree> >(Pair.MakePair(subtrees[i_1].First(), subtrees[j].First()), normF));
                    if (bestmatches.Count > 100)
                    {
                        bestmatches.Poll();
                    }
                }
                IList <ScoredObject <Pair <Tree, Tree> > > ordered = Generics.NewArrayList();
                while (bestmatches.Count > 0)
                {
                    ordered.Add(bestmatches.Poll());
                }
                Java.Util.Collections.Reverse(ordered);
                foreach (ScoredObject <Pair <Tree, Tree> > pair in ordered)
                {
                    log.Info(" MATCHED " + pair.Object().second.YieldWords() + " ... " + pair.Object().Second() + " with a score of " + pair.Score());
                }
                log.Info();
                log.Info();
                bestmatches.Clear();
            }

            /*
             * for (int i = 0; i < records.size(); ++i) {
             * if (i % 10 == 0) {
             * log.info("  " + i);
             * }
             * List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
             * for (int j = 0; j < records.size(); ++j) {
             * if (i == j) continue;
             *
             * double score = 0.0;
             * int matches = 0;
             * for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
             * for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
             * String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
             * String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
             * if (firstBasic.equals(secondBasic)) {
             * ++matches;
             * double normF = first.getValue().minus(second.getValue()).normF();
             * score += normF * normF;
             * }
             * }
             * }
             * if (matches == 0) {
             * score = Double.POSITIVE_INFINITY;
             * } else {
             * score = score / matches;
             * }
             * //double score = records.get(i).vector.minus(records.get(j).vector).normF();
             * scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
             * }
             * Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);
             *
             * out.write(records.get(i).sentence.toString() + "\n");
             * for (int j = 0; j < numNeighbors; ++j) {
             * out.write("   " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
             * }
             * out.write("\n\n");
             * }
             * log.info();
             */
            bout.Flush();
            @out.Flush();
            @out.Close();
        }