/// <summary>
/// Writes the decision forest and its tree index table.
/// </summary>
/// <remarks>
/// The tree index table is first written as a placeholder. Each tree is then serialized, and its
/// offset (relative to the start of this section) and size are recorded in
/// <paramref name="treeIndexes"/>. Finally the table is rewritten in place with the real values.
/// </remarks>
/// <param name="forest">The decision forest.</param>
/// <param name="treeIndexes">Tree indexes; entry i receives the offset and size of tree i.</param>
/// <param name="questionIndexes">Question indexes.</param>
/// <param name="questionSet">The Question set.</param>
/// <param name="namedOffsets">The named offsets, passed through to the forest serializer.</param>
/// <param name="forestSerializer">The forest serializer.</param>
/// <param name="writer">The writer to write.</param>
/// <returns>The number of bytes written for the decision tree section (index table plus trees).</returns>
/// <exception cref="ArgumentException">
/// <paramref name="treeIndexes"/> has fewer entries than the forest has trees.
/// </exception>
internal int Write(DecisionForest forest, List<TreeIndex> treeIndexes, Dictionary<string, uint> questionIndexes,
    HtsQuestionSet questionSet, IDictionary<string, uint[]> namedOffsets,
    DecisionForestSerializer forestSerializer, DataWriter writer)
{
    Helper.ThrowIfNull(forest);
    Helper.ThrowIfNull(treeIndexes);
    Helper.ThrowIfNull(writer);
    Helper.ThrowIfNull(questionIndexes);
    Helper.ThrowIfNull(questionSet);
    Helper.ThrowIfNull(forestSerializer);

    // NOTE(review): namedOffsets is not checked here; it may legitimately be optional for the serializer - confirm.

    // Validate before writing anything, so that a short index list cannot leave a half-written section.
    if (treeIndexes.Count < forest.TreeList.Count)
    {
        throw new ArgumentException(
            Helper.NeutralFormat(
                "Expected at least {0} tree indexes but got {1}.",
                forest.TreeList.Count,
                treeIndexes.Count),
            "treeIndexes");
    }

    int decisionTreeSectionStart = (int)writer.BaseStream.Position;
    int position = decisionTreeSectionStart;

    // Write tree index (place holder); the real offsets and sizes are filled in below.
    position += (int)WriteTreeIndexes(writer, treeIndexes.ToArray());

    // Write trees
    for (int treeIndex = 0; treeIndex < forest.TreeList.Count; treeIndex++)
    {
        DecisionTree tree = forest.TreeList[treeIndex];
        TreeIndex index = treeIndexes[treeIndex];
        index.Offset = position - decisionTreeSectionStart;
        index.Size = (int)forestSerializer.Write(tree, writer, questionIndexes, namedOffsets);

        // Store the entry back. This is a no-op if TreeIndex is a reference type,
        // but it is required for the update to stick if TreeIndex is a value type.
        treeIndexes[treeIndex] = index;
        position += index.Size;
    }

    // Rewrite the tree index table at the section start with the real values
    // (PositionRecover presumably restores the stream position on dispose).
    using (PositionRecover recover = new PositionRecover(writer, decisionTreeSectionStart, SeekOrigin.Begin))
    {
        WriteTreeIndexes(writer, treeIndexes.ToArray());
    }

    Debug.Assert(position % sizeof(uint) == 0, "Data should be 4-byte aligned.");

    return position - decisionTreeSectionStart;
}
/// <summary>
/// Loads the pre-selection data from two text files: a decision forest file and a
/// candidate group file. It then checks that both files describe the same set of leaves.
/// </summary>
/// <param name="forestFile">The file name of decision forest.</param>
/// <param name="candidateGroupFile">The file name of candidate group data.</param>
/// <param name="sentenceSet">The given sentence set where to find candidates.</param>
/// <exception cref="InvalidDataException">
/// Thrown in any of these cases:
/// a leaf node has no candidate group with its name;
/// the candidate group ids do not pass the continuity check;
/// the number of candidate groups differs from the total number of leaf nodes.
/// </exception>
public void LoadFromText(string forestFile, string candidateGroupFile, TrainingSentenceSet sentenceSet)
{
    _sentenceSet = sentenceSet;
    _decisionForest = new DecisionForest("pre-selection");
    _decisionForest.Load(forestFile);

    // Read candidate groups one after another until the reader reaches the end of the file.
    using (StreamReader groupReader = new StreamReader(candidateGroupFile))
    {
        while (!groupReader.EndOfStream)
        {
            CandidateGroup group = new CandidateGroup();
            group.Load(groupReader, sentenceSet);
            _nameIndexedCandidateGroup.Add(group.Name, group);
        }
    }

    // Every leaf of every tree must have a candidate group with the same name.
    int totalLeafCount = 0;
    foreach (DecisionTree decisionTree in _decisionForest.TreeList)
    {
        totalLeafCount += decisionTree.LeafNodeMap.Count;
        foreach (DecisionTreeNode leaf in decisionTree.LeafNodeMap.Values)
        {
            if (_nameIndexedCandidateGroup.ContainsKey(leaf.Name))
            {
                continue;
            }

            throw new InvalidDataException(
                Helper.NeutralFormat("Mismatched between file \"{0}\" and \"{1}\"", forestFile, candidateGroupFile));
        }
    }

    // Candidate group ids must be continuous and start with zero: 0, 1, ..., count - 1.
    List<int> expectedIds = Enumerable.Range(0, _nameIndexedCandidateGroup.Count).ToList();
    if (!Helper.Compare(expectedIds, _nameIndexedCandidateGroup.Select(entry => entry.Value.Id).ToArray(), true))
    {
        throw new InvalidDataException("The candidate group id should be continuous and starts with zero");
    }

    // There must be exactly one candidate group per leaf node.
    if (totalLeafCount != _nameIndexedCandidateGroup.Count)
    {
        throw new InvalidDataException(
            Helper.NeutralFormat("Mismatched between file \"{0}\" and \"{1}\"", forestFile, candidateGroupFile));
    }
}
/// <summary>
/// Saves the pre-selection forest and its candidate groups to a binary file.
/// </summary>
/// <remarks>
/// Sections are written in this order:
/// a header placeholder, the question set, the decision tree section, the string pool,
/// and the candidate set section.
/// The header is rewritten at offset 0 once every section offset and size is known.
/// The questions of <paramref name="decisionForest"/> are modified in place
/// (language and code value set) before anything is written.
/// </remarks>
/// <param name="decisionForest">The forest with each tree corresponding to a unit.</param>
/// <param name="candidateGroups">The candidate group collection.</param>
/// <param name="unitCandidateNameIds">Given candidate idx.</param>
/// <param name="customFeatures">Customized linguistic feature list.</param>
/// <param name="outputPath">The output path.</param>
public void Write(DecisionForest decisionForest, ICollection<CandidateGroup> candidateGroups,
    IDictionary<string, int> unitCandidateNameIds, HashSet<string> customFeatures, string outputPath)
{
    // Validate before mutating the questions or creating the output file.
    Helper.ThrowIfNull(decisionForest);
    Helper.ThrowIfNull(candidateGroups);
    Helper.ThrowIfNull(outputPath);

    foreach (Question question in decisionForest.QuestionList)
    {
        question.Language = _phoneSet.Language;
        question.ValueSetToCodeValueSet(_posSet, _phoneSet, customFeatures);
    }

    FileStream file = new FileStream(outputPath, FileMode.Create);
    try
    {
        using (DataWriter writer = new DataWriter(file))
        {
            // The writer now owns the stream, so the finally block must not dispose it again.
            file = null;
            uint position = 0;

            // Write header section place holder; it is rewritten once all offsets are known.
            PreselectionFileHeader header = new PreselectionFileHeader();
            position += (uint)header.Write(writer);

            HtsFontSerializer serializer = new HtsFontSerializer();

            // Write feature, question and prepare string pool
            HtsQuestionSet questionSet = new HtsQuestionSet
            {
                Items = decisionForest.QuestionList,
                Header = new HtsQuestionSetHeader { HasQuestionName = false },
                CustomFeatures = customFeatures,
            };

            using (StringPool stringPool = new StringPool())
            {
                Dictionary<string, uint> questionIndexes = new Dictionary<string, uint>();
                header.QuestionOffset = position;
                header.QuestionSize = serializer.Write(
                    questionSet, writer, stringPool, questionIndexes, customFeatures);
                position += header.QuestionSize;

                // Serialize the leaf referenced data into memory first. The tree section needs the
                // per-leaf offsets (namedSetOffset), but the data itself is appended at the end.
                IEnumerable<INodeData> dataNodes = GetCandidateNodes(candidateGroups);
                Dictionary<string, int> namedSetOffset = new Dictionary<string, int>();
                byte[] candidateSetData;
                using (MemoryStream candidateSetBuffer = new MemoryStream())
                {
                    // Dispose the buffer writer before reading the bytes, so nothing it buffered is lost.
                    using (DataWriter candidateSetWriter = new DataWriter(candidateSetBuffer))
                    {
                        HtsFontSerializer.Write(dataNodes, candidateSetWriter, namedSetOffset);
                    }

                    // MemoryStream.ToArray stays valid even if disposing the writer closed the stream.
                    candidateSetData = candidateSetBuffer.ToArray();
                }

                // Write decision forest
                Dictionary<string, uint[]> namedOffsets =
                    namedSetOffset.ToDictionary(p => p.Key, p => new[] { (uint)p.Value });
                header.DecisionTreeSectionOffset = position;
                header.DecisionTreeSectionSize = (uint)Write(decisionForest, unitCandidateNameIds,
                    questionIndexes, questionSet, namedOffsets, new DecisionForestSerializer(), writer);
                position += header.DecisionTreeSectionSize;

                // Write string pool
                header.StringPoolOffset = position;
                header.StringPoolSize = HtsFontSerializer.Write(stringPool, writer);
                position += header.StringPoolSize;

                // Write leaf referenced data
                header.CandidateSetSectionOffset = position;
                header.CandidateSetSectionSize = writer.Write(candidateSetData);
                position += header.CandidateSetSectionSize;

                // Rewrite the header at offset 0 now that every section offset and size is filled in.
                using (PositionRecover recover = new PositionRecover(writer, 0))
                {
                    header.Write(writer);
                }
            }
        }
    }
    finally
    {
        if (null != file)
        {
            file.Dispose();
        }
    }
}
/// <summary>
/// Initializes a new instance of the PreSelectionData class according to given forest and sentenceSet.
/// </summary>
/// <remarks>
/// Steps:
/// 1. Create one empty candidate group per leaf node of every tree.
/// 2. Route each non-silence candidate of the sentence set through the tree whose name equals
///    the candidate name, and add the candidate to the group of the leaf it reaches.
/// 3. Require every group to end up non-empty.
/// </remarks>
/// <param name="forest">The given forest.</param>
/// <param name="sentenceSet">The given sentence set where to find candidates.</param>
/// <param name="fullFeatureNameSet">The full feature set to parse tree.</param>
/// <exception cref="ArgumentNullException">Any argument is null.</exception>
/// <exception cref="InvalidDataException">
/// Thrown in any of these cases:
/// no tree is linked to a candidate's unit;
/// a candidate does not reach a leaf node;
/// a candidate group stays empty.
/// </exception>
public PreSelectionData(DecisionForest forest, TrainingSentenceSet sentenceSet, LabelFeatureNameSet fullFeatureNameSet)
{
    if (forest == null)
    {
        throw new ArgumentNullException("forest");
    }

    if (sentenceSet == null)
    {
        throw new ArgumentNullException("sentenceSet");
    }

    if (fullFeatureNameSet == null)
    {
        throw new ArgumentNullException("fullFeatureNameSet");
    }

    _decisionForest = forest;
    _sentenceSet = sentenceSet;
    _nameIndexedCandidateGroup = new Dictionary<string, CandidateGroup>();

    // Create empty candidate group.
    foreach (DecisionTree tree in forest.TreeList)
    {
        foreach (DecisionTreeNode node in tree.LeafNodeMap.Values)
        {
            CandidateGroup candidateGroup = new CandidateGroup
            {
                Name = node.Name,
                Id = _nameIndexedCandidateGroup.Count
            };

            _nameIndexedCandidateGroup.Add(candidateGroup.Name, candidateGroup);
        }
    }

    // Index the trees by name once, instead of scanning the whole tree list for every candidate.
    // ToLookup keeps source order within each name, so element [0] is the same tree Where() found first.
    ILookup<string, DecisionTree> treesByName = forest.TreeList.ToLookup(t => t.Name);

    // Travel the training sentence set to find the corresponding candidates.
    foreach (Sentence sentence in sentenceSet.Sentences.Values)
    {
        foreach (UnitCandidate candidate in sentence.Candidates)
        {
            if (!candidate.SilenceCandidate)
            {
                candidate.Label.FeatureNameSet = fullFeatureNameSet;
                DecisionTree[] linkedDecisionTrees = treesByName[candidate.Name].ToArray();

                // Debug.Assert is compiled out of release builds, so the missing-tree case must be a real check.
                if (linkedDecisionTrees.Length == 0)
                {
                    throw new InvalidDataException(
                        Helper.NeutralFormat("No Preselection tree is linked to unit {0}", candidate.Name));
                }

                Debug.Assert(linkedDecisionTrees.Length == 1,
                    Helper.NeutralFormat("Invalidated: More than 1 {0} Preselection tree are linked to unit {1}",
                        linkedDecisionTrees.Length, candidate.Name));

                DecisionTreeNode leafNode = DecisionForestExtension.FilterTree(
                    linkedDecisionTrees[0].NodeList[0], forest.Questions, candidate.Label);

                // Fail with a descriptive error instead of a NullReferenceException in release builds.
                if (leafNode == null)
                {
                    throw new InvalidDataException(
                        Helper.NeutralFormat("cannot find leaf node for candidate {0} in sentence {1}",
                            candidate.Name, sentence.Id));
                }

                _nameIndexedCandidateGroup[leafNode.Name].Candidates.Add(candidate);
            }
        }
    }

    // Verify there is no empty candidate group.
    foreach (CandidateGroup candidateGroup in _nameIndexedCandidateGroup.Values)
    {
        if (candidateGroup.Candidates.Count <= 0)
        {
            throw new InvalidDataException(
                Helper.NeutralFormat("There is no candidate in candidate group \"{0}\"", candidateGroup.Name));
        }
    }
}
/// <summary>
/// Combines given forests to a single forest.
/// </summary>
/// <remarks>
/// Trees are appended in forest order. A tree that is equal (default equality) to one already
/// added is skipped, which matches Enumerable.Union semantics.
/// Questions are merged by name. A question name that appears in several forests must have
/// the same expression in each.
/// </remarks>
/// <param name="name">The new forest name.</param>
/// <param name="forests">The given forests to be combined.</param>
/// <returns>The new forest which contains all the forests.</returns>
/// <exception cref="ArgumentNullException"><paramref name="forests"/> is null.</exception>
/// <exception cref="InvalidDataException">
/// Two forests define the same question name with different expressions.
/// </exception>
public static DecisionForest Combine(string name, IEnumerable<DecisionForest> forests)
{
    if (forests == null)
    {
        throw new ArgumentNullException("forests");
    }

    // A list plus a HashSet gives the same result as repeated Union (first occurrence wins, order kept)
    // without stacking one nested Union iterator per forest.
    List<DecisionTree> trees = new List<DecisionTree>();
    HashSet<DecisionTree> seenTrees = new HashSet<DecisionTree>();
    SortedDictionary<string, Question> nameIndexedQuestions = new SortedDictionary<string, Question>();
    foreach (DecisionForest forest in forests)
    {
        foreach (DecisionTree tree in forest.TreeList)
        {
            if (seenTrees.Add(tree))
            {
                trees.Add(tree);
            }
        }

        foreach (Question question in forest.Questions.Values)
        {
            Question existing;
            if (!nameIndexedQuestions.TryGetValue(question.Name, out existing))
            {
                nameIndexedQuestions.Add(question.Name, question);
            }
            else if (question.Expression != existing.Expression)
            {
                throw new InvalidDataException(
                    Helper.NeutralFormat("question \"{0}\" have two different expressions", question.Name));
            }
        }
    }

    DecisionForest newForest = new DecisionForest(name)
    {
        _nameIndexedQuestions = nameIndexedQuestions,
        _treeList = trees,
    };

    return newForest;
}