/// <summary>
/// Builds a one-row DataFrame from <paramref name="searchTerm"/> and passes it, in order,
/// through the supplied tokenizer, hashing term-frequency stage and fitted IDF model.
/// The "features" column of the result is renamed to "features2", and a "norm2" column
/// is added holding the output of <c>udfCalcNorm</c> applied to "features2".
/// </summary>
/// <param name="spark">Session used to create the one-row DataFrame.</param>
/// <param name="searchTerm">Text placed in the single "Content" cell.</param>
/// <param name="tokenizer">
/// First stage. NOTE(review): presumably configured to read the "Content" column; confirm at the call site.
/// </param>
/// <param name="hashingTF">Second stage, applied to the tokenizer output.</param>
/// <param name="idfModel">
/// Final stage. NOTE(review): the rename below assumes its output column is named "features"; confirm.
/// </param>
/// <returns>The transformed DataFrame, including the "features2" and "norm2" columns.</returns>
private static DataFrame GetSearchTermTFIDF(SparkSession spark, string searchTerm, Tokenizer tokenizer, HashingTF hashingTF, IDFModel idfModel)
{
    // A DataFrame created from a list of strings exposes its only column as "_1".
    var terms = new List<string> { searchTerm };
    DataFrame content = spark.CreateDataFrame(terms).WithColumnRenamed("_1", "Content");

    DataFrame tokens = tokenizer.Transform(content);
    DataFrame termFrequencies = hashingTF.Transform(tokens);
    DataFrame weighted = idfModel.Transform(termFrequencies);

    // The "2" suffix presumably keeps these columns distinct from another frame's
    // "features"/"norm" columns when the two are combined later (TODO confirm).
    DataFrame renamed = weighted.WithColumnRenamed("features", "features2");
    return renamed.WithColumn("norm2", udfCalcNorm(Col("features2")));
}
/// <summary>
/// Fits an <see cref="IDF"/> estimator on a single tokenized and hashed sentence, then checks that
/// the resulting <see cref="IDFModel"/> produces the configured output column, reports the
/// configured input column, output column and minimum document frequency, and keeps its Uid
/// across a save/load round trip.
/// </summary>
public void TestIDFModel()
{
    const string rawFeaturesCol = "rawFeatures";
    const string featuresCol = "features";
    // Only compared against GetMinDocFreq() below; no computed IDF values are asserted.
    const int minDocFreq = 1980;

    DataFrame sentences =
        _spark.Sql("SELECT 0.0 as label, 'Hi I heard about Spark' as sentence");

    // Stage 1: split the sentence into words.
    Tokenizer wordSplitter = new Tokenizer()
        .SetInputCol("sentence")
        .SetOutputCol("words");
    DataFrame words = wordSplitter.Transform(sentences);

    // Stage 2: hash the words into a 20-feature term-frequency column.
    HashingTF termHasher = new HashingTF()
        .SetInputCol("words")
        .SetOutputCol(rawFeaturesCol)
        .SetNumFeatures(20);
    DataFrame hashed = termHasher.Transform(words);

    // Stage 3: fit the IDF estimator and apply the fitted model.
    IDF estimator = new IDF()
        .SetInputCol(rawFeaturesCol)
        .SetOutputCol(featuresCol)
        .SetMinDocFreq(minDocFreq);
    IDFModel model = estimator.Fit(hashed);
    DataFrame rescaled = model.Transform(hashed);

    Assert.Contains(featuresCol, rescaled.Columns());
    Assert.Equal(rawFeaturesCol, model.GetInputCol());
    Assert.Equal(featuresCol, model.GetOutputCol());
    Assert.Equal(minDocFreq, model.GetMinDocFreq());

    // Persist the model to a temporary directory and reload it; the Uid must match.
    using (var scratchDir = new TemporaryDirectory())
    {
        string savePath = Path.Join(scratchDir.Path, "idfModel");
        model.Save(savePath);

        IDFModel reloaded = IDFModel.Load(savePath);
        Assert.Equal(model.Uid(), reloaded.Uid());
    }
}