Example #1
 /// <inheritdoc/>
 public bool TrySerializeAndSave(string filePath, object serializeableObject)
 {
     try
     {
         AtomicFileWriter.Write(
             filePath,
             xmlStream => XmlUtils.SerializeToXmlStream(serializeableObject, xmlStream, Encoding.UTF8));
         return true;
     }
     catch (Exception)
     {
         return false;
     }
 }
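
Note: AtomicFileWriter itself is not shown in this example. A helper like this typically writes to a temporary file and then renames it into place, so readers never observe a half-written file. Below is a minimal sketch of that pattern, assuming the names AtomicFileWriter/Write and the .NET Core 3.0+ File.Move overwrite overload; the real implementation may differ.

 using System;
 using System.IO;

 public static class AtomicFileWriter
 {
     // Write to a temp file next to the target, then move it into place.
     // The final move is a rename on the same volume, so readers never
     // see a partially written file.
     public static void Write(string filePath, Action<Stream> writeAction)
     {
         var tempPath = filePath + ".tmp";
         using (var stream = new FileStream(tempPath, FileMode.Create, FileAccess.Write))
             writeAction(stream);
         File.Move(tempPath, filePath, overwrite: true); // overload available since .NET Core 3.0
     }
 }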
Example #2
        /// <summary>
        /// Construct an SPM model from data; that is, train one.
        /// The input is passed as an IEnumerable or a ParallelQuery of lines of raw plain-text.
        /// The model is returned as a binary blob (for later use in encoding/decoding).
        /// Underneath, this uses the spm_train executable, which needs to store the model as a file. That location is
        /// passed in as 'tempSPMModelPath'. These output files are temporary and local to this function, but
        /// it is useful to keep them around for diagnostics and debugging; they are not (meant to be) used after this.
        /// 'minPieceCount' allows setting a minimum observation count for word pieces. spm_train does not support this,
        /// so we emulate/approximate it by running spm_train twice.
        /// </summary>
        public static SentencePieceModel Train<Enumerable>(Enumerable tokenStrings, string tempSPMModelPath,
                                                           SentencePieceTrainConfig spmParams, int minPieceCount, string spmBinDir)
            where Enumerable : IEnumerable<string> // using a generic type parameter so we won't lose parallelism (is this needed?)
        {
            Sanity.Requires(tempSPMModelPath.EndsWith(spmModelExt), $"FactoredSegmenter SentencePiece model path must end in {spmModelExt}");
            var modelPrefix = tempSPMModelPath.Substring(0, tempSPMModelPath.Length - spmModelExt.Length);

#if false   // helper during debugging of final Training stage when models already exist
            LoadSPMModelFiles(modelPrefix, out var spmModelBlob, out var spmVocab);
#else
            // write the tokens to a temp file
            var tempInputDataPath = modelPrefix + ".data";
            Logger.WriteLine($"FactoredSegmenter: Writing to temp file {tempInputDataPath} for SPM training...");
            AtomicFileWriter.Save(tempInputDataPath, tmpPath => File.WriteAllLines(tmpPath, tokenStrings, new UTF8Encoding()));
            // atomic writing allows the impatient user to know when the writing has completed and spm_train has taken over

            // invoke spm_train
            SPMTrain(tempInputDataPath, modelPrefix, spmParams, spmBinDir, null);

            // fetch the content of the generated .model and .vocab file into in-memory data structures
            // After this, the spm_train-generated files are no longer used; and only kept for debugging purposes.
            LoadSPMModelFiles(modelPrefix, out var spmModelBlob, out var spmVocab);

            // enforce minimum piece-count constraint
            if (minPieceCount > 1)
            {
                // encode the SPM training data and count each token's occurrence
                Logger.WriteLine($"FactoredSegmenter: Minimum-count constraint ({minPieceCount}), counting SPM tokens...");
                var coder = new SentencePieceCoder(new SentencePieceCoderConfig {
                    SentencePieceModel = new SentencePieceModel(spmModelBlob)
                });
                var counts = CountEncodedTokens(tempInputDataPath, coder);
                File.WriteAllLines(tempSPMModelPath + $".{spmVocab.Length}.counts", // save it for diagnostics only
                                   from kvp in counts orderby - kvp.Value, kvp.Key select $"{kvp.Key}\t{kvp.Value}");
                // count the number of SPM vocab items to keep: those meeting the threshold, plus single characters (always kept)
                var spmVocabSet       = new HashSet<string>(spmVocab);
                int adjustedVocabSize = counts.Count(kvp => spmVocabSet.Contains(kvp.Key) && (kvp.Key.Length == 1 || kvp.Value >= minPieceCount));
                // if there are units below the threshold, reduce the SPM vocab size and retrain
                if (adjustedVocabSize < spmVocab.Length)
                {
                    Logger.WriteLine($"FactoredSegmenter: Only {adjustedVocabSize} out of {spmVocab.Length} sentence pieces have {minPieceCount} or more observations." +
                                     $" Retraining SPM model with reduced vocabSize {adjustedVocabSize}");
                    // invoke spm_train a second time
                    SPMTrain(tempInputDataPath, modelPrefix, spmParams, spmBinDir, adjustedVocabSize);
                    LoadSPMModelFiles(modelPrefix, out spmModelBlob, out spmVocab); // reload the new model
                }
                // count once again for diagnostics only
                Logger.WriteLine($"FactoredSegmenter: Re-counting SPM tokens after reduction to {adjustedVocabSize}...");
                coder = new SentencePieceCoder(new SentencePieceCoderConfig {
                    SentencePieceModel = new SentencePieceModel(spmModelBlob)
                });
                counts = CountEncodedTokens(tempInputDataPath, coder);
                File.WriteAllLines(tempSPMModelPath + $".{adjustedVocabSize}.counts", // save for diagnostics only
                                   from kvp in counts orderby - kvp.Value, kvp.Key select $"{kvp.Key}\t{kvp.Value}");
            }

            // delete the temp file -- except on failure, so the user can double-check what's going on
            // commented out temporarily to aid debugging
            //File.Delete(tempInputDataPath);
#endif

            return new SentencePieceModel(spmModelBlob);
        }
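
For reference, CountEncodedTokens and LoadSPMModelFiles are helper functions outside this excerpt. A plausible sketch of the counting step is shown below; 'coder.Encode' is an assumed method name for splitting a line into sentence pieces, and the real SentencePieceCoder API may differ.

        using System.Collections.Generic;
        using System.IO;

        // Hypothetical sketch: count how often each sentence piece occurs when
        // encoding the SPM training data line by line.
        static Dictionary<string, int> CountEncodedTokens(string inputDataPath, SentencePieceCoder coder)
        {
            var counts = new Dictionary<string, int>();
            foreach (var line in File.ReadLines(inputDataPath))
                foreach (var piece in coder.Encode(line)) // assumed: returns the SPM pieces of one line
                    counts[piece] = counts.TryGetValue(piece, out var n) ? n + 1 : 1;
            return counts;
        }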