Пример #1
0
        /// <summary>
        /// Copies the elements of every matrix produced by the given enumerators into one flat vector.
        /// </summary>
        /// <remarks>
        /// Enumerators are drained in argument order, and each matrix contributes its elements in
        /// ascending linear-index order. Because the result array has exactly
        /// <paramref name="totalSize"/> slots, supplying more elements than that fails on the array
        /// write before the final size check is reached.
        /// </remarks>
        /// <param name="totalSize">Number of elements the matrices are expected to supply in total.</param>
        /// <param name="matrices">Enumerators over the matrices to copy; each one is advanced to its end.</param>
        /// <returns>A new array of length <paramref name="totalSize"/> holding the copied values.</returns>
        /// <exception cref="AssertionError">Fewer than <paramref name="totalSize"/> elements were copied.</exception>
        public static double[] ParamsToVector(int totalSize, params IEnumerator <SimpleMatrix>[] matrices)
        {
            double[] flattened = new double[totalSize];
            int      cursor    = 0;

            for (int m = 0; m < matrices.Length; m++)
            {
                IEnumerator <SimpleMatrix> source = matrices[m];
                while (source.MoveNext())
                {
                    SimpleMatrix current = source.Current;
                    int          count   = current.GetNumElements();
                    int          offset  = 0;
                    while (offset < count)
                    {
                        flattened[cursor] = current.Get(offset);
                        cursor++;
                        offset++;
                    }
                }
            }

            // Every slot of the result must have been written exactly once.
            if (cursor == totalSize)
            {
                return flattened;
            }
            throw new AssertionError("Did not entirely fill the theta vector: expected " + totalSize + " used " + cursor);
        }
 /// <summary>
 /// Writes all elements of a matrix on a single line, each preceded by two spaces,
 /// and then terminates the line.
 /// </summary>
 /// <param name="bout">Writer that receives the text.</param>
 /// <param name="matrix">Matrix whose elements are written in ascending linear-index order.</param>
 /// <exception cref="System.IO.IOException"/>
 public static void OutputMatrix(BufferedWriter bout, SimpleMatrix matrix)
 {
     int position = 0;

     while (position < matrix.GetNumElements())
     {
         string cell = "  " + matrix.Get(position);
         bout.Write(cell);
         position++;
     }
     bout.NewLine();
 }
        /// <summary>
        /// Finds the zero-based position of the largest entry in
        /// <paramref name="predictions"/>.
        /// </summary>
        /// <remarks>
        /// Only a strictly larger entry replaces the current best, so the earliest position wins
        /// on ties. Position 0 is returned when there is nothing to compare against it.
        /// </remarks>
        /// <param name="predictions">Matrix of scores, addressed by linear index.</param>
        /// <returns>The linear index of the highest score.</returns>
        private static int GetPredictedClass(SimpleMatrix predictions)
        {
            int best      = 0;
            int candidate = 1;

            while (candidate < predictions.GetNumElements())
            {
                bool improves = predictions.Get(candidate) > predictions.Get(best);
                best = improves ? candidate : best;
                ++candidate;
            }
            return best;
        }
        /// <summary>
        /// Builds the gradient tensor for a node from its delta and the vectors of its two children.
        /// </summary>
        /// <remarks>
        /// With n = number of elements of <paramref name="deltaFull"/>, the result is created as a
        /// (2n, 2n, n) tensor. Slice k is set to (v * delta[k]) multiplied by the transpose of v,
        /// where v is the concatenation of the left and right vectors.
        /// NOTE(review): this presumably expects column vectors so that each slice is 2n x 2n; confirm.
        /// </remarks>
        /// <param name="deltaFull">Delta of the node; one slice is produced per element.</param>
        /// <param name="leftVector">Vector of the left child.</param>
        /// <param name="rightVector">Vector of the right child.</param>
        /// <returns>The newly created tensor with every slice filled in.</returns>
        private static SimpleTensor GetTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector)
        {
            int          dim      = deltaFull.GetNumElements();
            int          inputDim = dim * 2;
            SimpleTensor gradient = new SimpleTensor(inputDim, inputDim, dim);
            // TODO: combine this concatenation with computeTensorDeltaDown?
            SimpleMatrix children = NeuralUtils.Concatenate(leftVector, rightVector);

            int k = 0;
            while (k < dim)
            {
                SimpleMatrix weighted = children.Scale(deltaFull.Get(k));
                SimpleMatrix product  = weighted.Mult(children.Transpose());
                gradient.SetSlice(k, product);
                ++k;
            }
            return gradient;
        }
Пример #5
0
        /// <summary>Checks whether a matrix consists solely of zeros.</summary>
        /// <param name="matrix">Matrix to scan in ascending linear-index order.</param>
        /// <returns>
        /// <c>true</c> when no element differs from 0.0 (which includes a matrix without elements);
        /// <c>false</c> as soon as one element does.
        /// </returns>
        public static bool IsZero(SimpleMatrix matrix)
        {
            int  count    = matrix.GetNumElements();
            int  position = 0;
            bool allZero  = true;

            // Stop scanning at the first element that is not exactly zero.
            while (allZero && position < count)
            {
                allZero = matrix.Get(position) == 0.0;
                position++;
            }
            return allZero;
        }
        /// <summary>
        /// Computes the delta passed down to the children of a node, combining the contribution of
        /// the weight matrix <paramref name="W"/> with that of the tensor <paramref name="Wt"/>.
        /// </summary>
        /// <remarks>
        /// The matrix part is W-transposed times <paramref name="deltaFull"/>, cut down to its first
        /// 2 * rows(deltaFull) rows of column 0 (NOTE(review): the dropped rows are presumably the
        /// bias entries; confirm). The tensor part sums, over every slice k,
        /// (Wt[k] + transpose(Wt[k])) * (v * delta[k]), where v is the concatenation of the left and
        /// right vectors.
        /// </remarks>
        /// <param name="deltaFull">Delta of the node.</param>
        /// <param name="leftVector">Vector of the left child.</param>
        /// <param name="rightVector">Vector of the right child.</param>
        /// <param name="W">Weight matrix of the node.</param>
        /// <param name="Wt">Weight tensor of the node; one slice per element of the delta.</param>
        /// <returns>The sum of the tensor part and the matrix part.</returns>
        private static SimpleMatrix ComputeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector, SimpleMatrix W, SimpleTensor Wt)
        {
            SimpleMatrix matrixPart       = W.Transpose().Mult(deltaFull);
            SimpleMatrix matrixPartNoBias = matrixPart.ExtractMatrix(0, deltaFull.NumRows() * 2, 0, 1);
            int          dim              = deltaFull.GetNumElements();
            SimpleMatrix tensorPart       = new SimpleMatrix(dim * 2, 1);
            SimpleMatrix children         = NeuralUtils.Concatenate(leftVector, rightVector);

            int k = 0;
            while (k < dim)
            {
                SimpleMatrix scaledChildren  = children.Scale(deltaFull.Get(k));
                SimpleMatrix slice           = Wt.GetSlice(k);
                SimpleMatrix sliceTransposed = Wt.GetSlice(k).Transpose();
                SimpleMatrix contribution    = slice.Plus(sliceTransposed).Mult(scaledChildren);
                tensorPart = tensorPart.Plus(contribution);
                ++k;
            }
            return tensorPart.Plus(matrixPartNoBias);
        }
Пример #7
0
        /// <summary>Writes the prediction scores of every non-leaf node of a tree.</summary>
        /// <remarks>
        /// Nodes are visited parent first, then children in order. Each non-leaf node produces one
        /// line: its index followed by a colon and then its formatted scores. Leaves produce no
        /// output and do not consume an index. The node counting is meant to match setIndexLabels.
        /// </remarks>
        /// <param name="out">Writer that receives the lines.</param>
        /// <param name="tree">Root of the (sub)tree to output.</param>
        /// <param name="index">Index to assign to <paramref name="tree"/> if it is not a leaf.</param>
        /// <returns>The next unused index after this subtree has been written.</returns>
        private static int OutputTreeScores(TextWriter @out, Tree tree, int index)
        {
            if (tree.IsLeaf())
            {
                return index;
            }

            @out.Write("  " + index + ':');
            SimpleMatrix scores   = RNNCoreAnnotations.GetPredictions(tree);
            int          position = 0;
            while (position < scores.GetNumElements())
            {
                @out.Write("  " + Nf.Format(scores.Get(position)));
                position++;
            }
            @out.WriteLine();

            // This node used up one index; each child subtree reports where numbering continues.
            int next = index + 1;
            foreach (Tree child in tree.Children())
            {
                next = OutputTreeScores(@out, child, next);
            }
            return next;
        }
Пример #8
0
        /// <summary>
        /// Writes the values of a flat vector back into the matrices produced by the given enumerators.
        /// </summary>
        /// <remarks>
        /// Enumerators are drained in argument order, and each matrix receives its elements in
        /// ascending linear-index order. If the matrices need more values than
        /// <paramref name="theta"/> holds, the array read fails before the final size check is reached.
        /// </remarks>
        /// <param name="theta">Source values, consumed front to back.</param>
        /// <param name="matrices">Enumerators over the matrices to overwrite; each one is advanced to its end.</param>
        /// <exception cref="AssertionError">The matrices consumed fewer values than <paramref name="theta"/> contains.</exception>
        public static void VectorToParams(double[] theta, params IEnumerator <SimpleMatrix>[] matrices)
        {
            int cursor = 0;

            for (int m = 0; m < matrices.Length; m++)
            {
                IEnumerator <SimpleMatrix> target = matrices[m];
                while (target.MoveNext())
                {
                    SimpleMatrix current = target.Current;
                    int          count   = current.GetNumElements();
                    int          offset  = 0;
                    while (offset < count)
                    {
                        current.Set(offset, theta[cursor]);
                        cursor++;
                        offset++;
                    }
                }
            }

            // Every value of theta must have been handed to some matrix.
            if (cursor == theta.Length)
            {
                return;
            }
            throw new AssertionError("Did not entirely use the theta vector");
        }
        /// <summary>
        /// Loads the word embeddings named by <c>op.lexOptions.wordVectorFile</c> into
        /// <c>wordVectors</c> and registers the special-purpose vectors.
        /// </summary>
        /// <remarks>
        /// Each embedding is stored under its word after <c>op.wordFunction</c> (when set) has been
        /// applied to it. If <c>op.lexOptions.numHid</c> is not positive, it is set to the element
        /// count of a loaded vector. For every enabled <c>op.trainOptions.unknown*Vector</c> option,
        /// the vectors of all words matching the corresponding pattern are averaged and stored under
        /// the matching unknown-word key; when no word matched, a copy of the generic unknown-word
        /// vector is stored instead. When <c>op.trainOptions.useContextWords</c> is set, random
        /// start and end vectors are registered as well.
        /// </remarks>
        /// <exception cref="System.Exception">
        /// No vector is available for <c>op.trainOptions.unkWord</c>.
        /// </exception>
        public virtual void ReadWordVectors()
        {
            SimpleMatrix unknownNumberVector         = null;
            SimpleMatrix unknownCapsVector           = null;
            SimpleMatrix unknownChineseYearVector    = null;
            SimpleMatrix unknownChineseNumberVector  = null;
            SimpleMatrix unknownChinesePercentVector = null;

            wordVectors = Generics.NewTreeMap();
            int numberCount         = 0;
            int capsCount           = 0;
            int chineseYearCount    = 0;
            int chineseNumberCount  = 0;
            int chinesePercentCount = 0;
            Embedding rawWordVectors = new Embedding(op.lexOptions.wordVectorFile, op.lexOptions.numHid);

            foreach (string rawWord in rawWordVectors.KeySet())
            {
                SimpleMatrix vector = rawWordVectors.Get(rawWord);
                // A foreach iteration variable is read-only in C#, so the normalized form of the
                // word is kept in its own local.
                string word = rawWord;
                if (op.wordFunction != null)
                {
                    word = op.wordFunction.Apply(rawWord);
                }
                wordVectors[word] = vector;
                if (op.lexOptions.numHid <= 0)
                {
                    op.lexOptions.numHid = vector.GetNumElements();
                }
                // Sum the vectors of each enabled category; the sums are averaged after the loop.
                if (op.trainOptions.unknownNumberVector && (NumberPattern.Matcher(word).Matches() || DgPattern.Matcher(word).Matches()))
                {
                    ++numberCount;
                    unknownNumberVector = AddToVectorSum(unknownNumberVector, vector);
                }
                if (op.trainOptions.unknownCapsVector && CapsPattern.Matcher(word).Matches())
                {
                    ++capsCount;
                    unknownCapsVector = AddToVectorSum(unknownCapsVector, vector);
                }
                if (op.trainOptions.unknownChineseYearVector && ChineseYearPattern.Matcher(word).Matches())
                {
                    ++chineseYearCount;
                    unknownChineseYearVector = AddToVectorSum(unknownChineseYearVector, vector);
                }
                if (op.trainOptions.unknownChineseNumberVector && (ChineseNumberPattern.Matcher(word).Matches() || DgPattern.Matcher(word).Matches()))
                {
                    ++chineseNumberCount;
                    unknownChineseNumberVector = AddToVectorSum(unknownChineseNumberVector, vector);
                }
                if (op.trainOptions.unknownChinesePercentVector && ChinesePercentPattern.Matcher(word).Matches())
                {
                    ++chinesePercentCount;
                    unknownChinesePercentVector = AddToVectorSum(unknownChinesePercentVector, vector);
                }
            }
            string unkWord = op.trainOptions.unkWord;

            if (op.wordFunction != null)
            {
                unkWord = op.wordFunction.Apply(unkWord);
            }
            SimpleMatrix unknownWordVector = wordVectors[unkWord];

            // Validate before registering so a missing vector never ends up stored under UnknownWord.
            if (unknownWordVector == null)
            {
                throw new Exception("Unknown word vector not specified in the word vector file");
            }
            wordVectors[UnknownWord] = unknownWordVector;

            if (op.trainOptions.unknownNumberVector)
            {
                wordVectors[UnknownNumber] = AverageOrFallback(unknownNumberVector, numberCount, unknownWordVector);
            }
            if (op.trainOptions.unknownCapsVector)
            {
                wordVectors[UnknownCaps] = AverageOrFallback(unknownCapsVector, capsCount, unknownWordVector);
            }
            if (op.trainOptions.unknownChineseYearVector)
            {
                log.Info("Matched " + chineseYearCount + " chinese year vectors");
                wordVectors[UnknownChineseYear] = AverageOrFallback(unknownChineseYearVector, chineseYearCount, unknownWordVector);
            }
            if (op.trainOptions.unknownChineseNumberVector)
            {
                log.Info("Matched " + chineseNumberCount + " chinese number vectors");
                wordVectors[UnknownChineseNumber] = AverageOrFallback(unknownChineseNumberVector, chineseNumberCount, unknownWordVector);
            }
            if (op.trainOptions.unknownChinesePercentVector)
            {
                log.Info("Matched " + chinesePercentCount + " chinese percent vectors");
                wordVectors[UnknownChinesePercent] = AverageOrFallback(unknownChinesePercentVector, chinesePercentCount, unknownWordVector);
            }
            if (op.trainOptions.useContextWords)
            {
                SimpleMatrix start = SimpleMatrix.Random(op.lexOptions.numHid, 1, -0.5, 0.5, rand);
                SimpleMatrix end   = SimpleMatrix.Random(op.lexOptions.numHid, 1, -0.5, 0.5, rand);
                wordVectors[StartWord] = start;
                wordVectors[EndWord]   = end;
            }
        }

        /// <summary>
        /// Adds <paramref name="vector"/> to a running sum, starting the sum with a copy of the
        /// vector when no sum exists yet.
        /// </summary>
        private static SimpleMatrix AddToVectorSum(SimpleMatrix runningSum, SimpleMatrix vector)
        {
            if (runningSum == null)
            {
                return new SimpleMatrix(vector);
            }
            return runningSum.Plus(vector);
        }

        /// <summary>
        /// Returns the running sum divided by <paramref name="count"/>, or a copy of
        /// <paramref name="fallback"/> when nothing was accumulated.
        /// </summary>
        private static SimpleMatrix AverageOrFallback(SimpleMatrix runningSum, int count, SimpleMatrix fallback)
        {
            if (count > 0)
            {
                return runningSum.Divide(count);
            }
            return new SimpleMatrix(fallback);
        }
        /// <summary>
        /// Command-line entry point: parses every tree of a test treebank with a parser that has a
        /// DVModel reranker attached, writes each sentence, its parse and its root vector to the
        /// output file, and then logs, for every sufficiently small subtree, the other subtrees
        /// ranked by the Frobenius norm of the difference between their node vectors.
        /// </summary>
        /// <param name="args">
        /// Recognized flags (case-insensitive): <c>-model PATH</c>, <c>-testTreebank DESCRIPTION</c>
        /// and <c>-output PATH</c>; all three are required. Any other argument is forwarded to the
        /// parser loader.
        /// </param>
        /// <exception cref="System.Exception"/>
        public static void Main(string[] args)
        {
            string         modelPath          = null;
            string         outputPath         = null;
            string         testTreebankPath   = null;
            IFileFilter    testTreebankFilter = null;
            IList <string> unusedArgs         = new List <string>();

            // The loop header has no increment: each flag consumes a different number of arguments.
            for (int argIndex = 0; argIndex < args.Length;)
            {
                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model"))
                {
                    modelPath = args[argIndex + 1];
                    argIndex += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-testTreebank"))
                {
                    Pair <string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
                    argIndex           = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                    testTreebankPath   = treebankDescription.First();
                    testTreebankFilter = treebankDescription.Second();
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output"))
                {
                    outputPath = args[argIndex + 1];
                    argIndex  += 2;
                }
                else
                {
                    // Unrecognized arguments are passed on to the parser loader below.
                    unusedArgs.Add(args[argIndex++]);
                }
            }
            if (modelPath == null)
            {
                throw new ArgumentException("Need to specify -model");
            }
            if (testTreebankPath == null)
            {
                throw new ArgumentException("Need to specify -testTreebank");
            }
            if (outputPath == null)
            {
                throw new ArgumentException("Need to specify -output");
            }
            string[]          newArgs   = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
            LexicalizedParser lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs));

            // testTreebankPath cannot be null here: its absence was rejected above.
            log.Info("Reading in trees from " + testTreebankPath);
            if (testTreebankFilter != null)
            {
                log.Info("Filtering on " + testTreebankFilter);
            }
            Treebank testTreebank = lexparser.GetOp().tlpParams.MemoryTreebank();
            testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
            log.Info("Read in " + testTreebank.Count + " trees for testing");

            FileWriter @out = new FileWriter(outputPath);

            // try/finally guarantees the output file is closed even when parsing fails part-way.
            try
            {
                // NOTE(review): bout is only flushed, never written to; all writes go straight to @out.
                BufferedWriter bout = new BufferedWriter(@out);

                log.Info("Parsing " + testTreebank.Count + " trees");
                int count = 0;
                IList <FindNearestNeighbors.ParseRecord> records = Generics.NewArrayList();

                foreach (Tree goldTree in testTreebank)
                {
                    IList <Word> tokens      = goldTree.YieldWords();
                    IParserQuery parserQuery = lexparser.ParserQuery();
                    if (!parserQuery.Parse(tokens))
                    {
                        throw new AssertionError("Could not parse: " + tokens);
                    }
                    if (!(parserQuery is RerankingParserQuery))
                    {
                        throw new ArgumentException("Expected a LexicalizedParser with a Reranker attached");
                    }
                    RerankingParserQuery rpq = (RerankingParserQuery)parserQuery;
                    if (!(rpq.RerankerQuery() is DVModelReranker.Query))
                    {
                        throw new ArgumentException("Expected a LexicalizedParser with a DVModel attached");
                    }
                    DeepTree     tree       = ((DVModelReranker.Query)rpq.RerankerQuery()).GetDeepTrees()[0];
                    SimpleMatrix rootVector = null;
                    // The sentence vector is the vector attached to the node labelled ROOT.
                    foreach (KeyValuePair <Tree, SimpleMatrix> entry in tree.GetVectors())
                    {
                        if (entry.Key.Label().Value().Equals("ROOT"))
                        {
                            rootVector = entry.Value;
                            break;
                        }
                    }
                    if (rootVector == null)
                    {
                        throw new AssertionError("Could not find root nodevector");
                    }
                    @out.Write(tokens + "\n");
                    @out.Write(tree.GetTree() + "\n");
                    for (int i = 0; i < rootVector.GetNumElements(); ++i)
                    {
                        @out.Write("  " + rootVector.Get(i));
                    }
                    @out.Write("\n\n\n");
                    count++;
                    if (count % 10 == 0)
                    {
                        log.Info("  " + count);
                    }
                    records.Add(new FindNearestNeighbors.ParseRecord(tokens, goldTree, tree.GetTree(), rootVector, tree.GetVectors()));
                }
                log.Info("  done parsing");
                IList <Pair <Tree, SimpleMatrix> > subtrees = Generics.NewArrayList();

                // Collect every node whose yield is at most maxLength leaves, across all parses.
                foreach (FindNearestNeighbors.ParseRecord record in records)
                {
                    foreach (KeyValuePair <Tree, SimpleMatrix> entry in record.nodeVectors)
                    {
                        if (entry.Key.GetLeaves().Count <= maxLength)
                        {
                            subtrees.Add(Pair.MakePair(entry.Key, entry.Value));
                        }
                    }
                }
                log.Info("There are " + subtrees.Count + " subtrees in the set of trees");
                PriorityQueue <ScoredObject <Pair <Tree, Tree> > > bestmatches = new PriorityQueue <ScoredObject <Pair <Tree, Tree> > >(101, ScoredComparator.DescendingComparator);

                for (int i_1 = 0; i_1 < subtrees.Count; ++i_1)
                {
                    log.Info(subtrees[i_1].First().YieldWords());
                    log.Info(subtrees[i_1].First());
                    for (int j = 0; j < subtrees.Count; ++j)
                    {
                        if (i_1 == j)
                        {
                            continue;
                        }
                        // TODO: look at basic category?
                        double normF = subtrees[i_1].Second().Minus(subtrees[j].Second()).NormF();
                        bestmatches.Add(new ScoredObject <Pair <Tree, Tree> >(Pair.MakePair(subtrees[i_1].First(), subtrees[j].First()), normF));
                        // Cap the queue at 100 entries. NOTE(review): with the descending comparator
                        // Poll presumably discards the largest distance; confirm.
                        if (bestmatches.Count > 100)
                        {
                            bestmatches.Poll();
                        }
                    }
                    // Drain the queue and reverse the drained order before reporting.
                    IList <ScoredObject <Pair <Tree, Tree> > > ordered = Generics.NewArrayList();
                    while (bestmatches.Count > 0)
                    {
                        ordered.Add(bestmatches.Poll());
                    }
                    Java.Util.Collections.Reverse(ordered);
                    foreach (ScoredObject <Pair <Tree, Tree> > pair in ordered)
                    {
                        log.Info(" MATCHED " + pair.Object().second.YieldWords() + " ... " + pair.Object().Second() + " with a score of " + pair.Score());
                    }
                    log.Info();
                    log.Info();
                    bestmatches.Clear();
                }

                bout.Flush();
                @out.Flush();
            }
            finally
            {
                @out.Close();
            }
        }