/// <summary>
/// Writes the content of a data view to a file with the "Text" saver.
/// </summary>
/// <param name="env">environment used to create the saver</param>
/// <param name="tr">view to save</param>
/// <param name="outFilePath">file to create (an existing file is overwritten)</param>
/// <param name="subsetColumns">names forwarded to <c>GetColumnsIndex</c> together with
/// the schema of <paramref name="tr"/> to select the saved columns
/// (presumably every column when null, to be confirmed in <c>GetColumnsIndex</c>)</param>
public static void SavePredictions(IHostEnvironment env, IDataView tr, string outFilePath,
                                   IEnumerable<string> subsetColumns = null)
{
    var textSaver = env.CreateSaver("Text");
    var selectedColumns = GetColumnsIndex(tr.Schema, subsetColumns);

    // The output stream only lives for the duration of the save.
    using (var output = File.Create(outFilePath))
    {
        textSaver.SaveData(output, tr, selectedColumns);
    }
}
/// <summary>
/// Loads the transforms stored in a model file, applies them to
/// <paramref name="data"/> and saves the resulting view with the "Text" saver.
/// </summary>
/// <param name="env">environment used to load the transforms and to create the saver</param>
/// <param name="modelPath">model file to read</param>
/// <param name="outFilePath">file to create (an existing file is overwritten)</param>
/// <param name="data">view the loaded transforms are applied to</param>
/// <param name="subsetColumns">names forwarded to <c>GetColumnsIndex</c> to select the saved columns
/// (presumably every column when null, to be confirmed in <c>GetColumnsIndex</c>)</param>
public static void SavePredictions(IHostEnvironment env, string modelPath,
                                   string outFilePath, IDataView data,
                                   IEnumerable<string> subsetColumns = null)
{
    using (var fs = File.OpenRead(modelPath))
    {
        var deserializedData = env.LoadTransforms(fs, data);
        var saver2 = env.CreateSaver("Text");

        // The column indexes are handed to SaveData together with the transformed view,
        // so they must be resolved against that view's schema and not against the schema
        // of the input data: otherwise a column created by the loaded transforms
        // can never be selected.
        var columns = GetColumnsIndex(deserializedData.Schema, subsetColumns);

        using (var fs2 = File.Create(outFilePath))
        {
            saver2.SaveData(fs2, deserializedData, columns);
        }
    }
}
/// <summary>
/// Finalizes the test on a predictor, calls the predictor with a scorer,
/// saves the data, saves the models, loads it back, saves the data again,
/// checks the output is the same. For every kind but clustering, it then
/// reads the label and the prediction of every row and runs basic sanity
/// checks on them.
/// </summary>
/// <param name="env">environment</param>
/// <param name="outModelFilePath">output filename</param>
/// <param name="predictor">predictor</param>
/// <param name="roles">label, feature, ...</param>
/// <param name="outData">first output data</param>
/// <param name="outData2">second output data</param>
/// <param name="kind">prediction kind</param>
/// <param name="checkError">checks errors</param>
/// <param name="ratio">check the error is below that threshold (if checkError is true)</param>
/// <param name="ratioReadSave">check the predictions difference after reloading the model are below this threshold</param>
/// <param name="checkType">checks the reloaded predictor has the same type as <paramref name="predictor"/></param>
public static void FinalizeSerializationTest(IHostEnvironment env,
                    string outModelFilePath, IPredictor predictor,
                    RoleMappedData roles, string outData, string outData2,
                    PredictionKind kind, bool checkError = true,
                    float ratio = 0.8f, float ratioReadSave = 0.06f,
                    bool checkType = true)
{
    string labelColumn = kind != PredictionKind.Clustering ? roles.Schema.Label.Value.Name : null;

    #region save, reading, running

    // Saves model.
    using (var ch = env.Start("Save"))
    using (var fs = File.Create(outModelFilePath))
    {
        TrainUtils.SaveModel(env, ch, fs, predictor, roles);
    }
    if (!File.Exists(outModelFilePath))
    {
        throw new FileNotFoundException(outModelFilePath);
    }

    // Loads the model back.
    using (var fs = File.OpenRead(outModelFilePath))
    {
        var pred_local = ModelFileUtils.LoadPredictorOrNull(env, fs);
        if (pred_local == null)
        {
            throw new Exception(string.Format("Unable to load '{0}'", outModelFilePath));
        }
        if (checkType && predictor.GetType() != pred_local.GetType())
        {
            throw new Exception(string.Format("Type mismatch {0} != {1}", predictor.GetType(), pred_local.GetType()));
        }
    }

    // Checks the outputs.
    var sch1 = SchemaHelper.ToString(roles.Schema.Schema);
    var scorer = PredictorHelper.CreateDefaultScorer(env, roles, predictor);
    var sch2 = SchemaHelper.ToString(scorer.Schema);
    if (string.IsNullOrEmpty(sch1) || string.IsNullOrEmpty(sch2))
    {
        throw new Exception("Empty schemas");
    }

    // Only the columns the text saver accepts are written.
    var saver = env.CreateSaver("Text");
    var columns = Enumerable.Range(0, scorer.Schema.Count)
                            .Where(c => saver.IsColumnSavable(scorer.Schema[c].Type))
                            .ToArray();
    using (var fs2 = File.Create(outData))
    {
        saver.SaveData(fs2, scorer, columns);
    }
    // The file which was just written is the one to check.
    if (!File.Exists(outData))
    {
        throw new FileNotFoundException(outData);
    }

    // Check we have the same output.
    using (var fs = File.OpenRead(outModelFilePath))
    {
        var model = ModelFileUtils.LoadPredictorOrNull(env, fs);
        scorer = PredictorHelper.CreateDefaultScorer(env, roles, model);
        saver = env.CreateSaver("Text");
        using (var fs2 = File.Create(outData2))
        {
            saver.SaveData(fs2, scorer, columns);
        }
    }

    var t1 = File.ReadAllLines(outData);
    var t2 = File.ReadAllLines(outData2);
    if (t1.Length != t2.Length)
    {
        throw new Exception(string.Format("Not the same number of lines: {0} != {1}", t1.Length, t2.Length));
    }

    // Indexes of the lines which differ between the two files.
    var linesN = new List<int>();
    for (int i = 0; i < t1.Length; ++i)
    {
        if (t1[i] != t2[i])
        {
            linesN.Add(i);
        }
    }
    if (linesN.Count > (int)(t1.Length * ratioReadSave))
    {
        var rows = linesN.Select(i => string.Format("1-Mismatch on line {0}/{3}:\n{1}\n{2}", i, t1[i], t2[i], t1.Length)).ToList();
        rows.Add($"Number of differences: {linesN.Count}/{t1.Length}");
        throw new Exception(string.Join("\n", rows));
    }

    #endregion

    #region clustering

    if (kind == PredictionKind.Clustering)
    {
        // Nothing to do here.
        return;
    }

    #endregion

    #region supervized

    string expectedOutput = kind == PredictionKind.Regression ? "Score" : "PredictedLabel";

    // Get label and basic checking about performance.
    using (var cursor = scorer.GetRowCursor(scorer.Schema))
    {
        int ilabel = SchemaHelper.GetColumnIndex(cursor.Schema, labelColumn);
        int ipred = SchemaHelper.GetColumnIndex(cursor.Schema, expectedOutput);
        var ty1 = cursor.Schema[ilabel].Type;
        var ty2 = cursor.Schema[ipred].Type;

        // dist1: number of rows per label, dist2: number of rows per prediction,
        // conf: number of rows per (prediction, label) pair.
        var dist1 = new Dictionary<int, int>();
        var dist2 = new Dictionary<int, int>();
        var conf = new Dictionary<Tuple<int, int>, long>();

        if (kind == PredictionKind.MulticlassClassification)
        {
            #region Multiclass
            if (!ty2.IsKey())
            {
                throw new Exception(string.Format("Label='{0}' Predicted={1}'\nSchema: {2}", ty1, ty2, SchemaHelper.ToString(cursor.Schema)));
            }

            if (ty1.RawKind() == DataKind.Single)
            {
                var lgetter = cursor.GetGetter<float>(SchemaHelper._dc(ilabel, cursor));
                var pgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ipred, cursor));
                float ans = 0;
                uint pre = 0;
                while (cursor.MoveNext())
                {
                    lgetter(ref ans);
                    pgetter(ref pre);

                    // The scorer +1 to the argmax.
                    ++ans;
                    RecordPrediction(conf, dist1, dist2, (int)pre, (int)ans);
                }
            }
            else if (ty1.RawKind() == DataKind.UInt32 && ty1.IsKey())
            {
                var lgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ilabel, cursor));
                var pgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ipred, cursor));
                uint ans = 0;
                uint pre = 0;
                while (cursor.MoveNext())
                {
                    lgetter(ref ans);
                    pgetter(ref pre);
                    RecordPrediction(conf, dist1, dist2, (int)pre, (int)ans);
                }
            }
            else
            {
                throw new NotImplementedException(string.Format("Not implemented for type {0}", ty1.ToString()));
            }
            #endregion
        }
        else if (kind == PredictionKind.BinaryClassification)
        {
            #region binary classification
            if (ty2.RawKind() != DataKind.Boolean)
            {
                throw new Exception(string.Format("Label='{0}' Predicted={1}'\nSchema: {2}", ty1, ty2, SchemaHelper.ToString(cursor.Schema)));
            }

            if (ty1.RawKind() == DataKind.Single)
            {
                var lgetter = cursor.GetGetter<float>(SchemaHelper._dc(ilabel, cursor));
                var pgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ipred, cursor));
                float ans = 0;
                bool pre = default(bool);
                while (cursor.MoveNext())
                {
                    lgetter(ref ans);
                    pgetter(ref pre);
                    if (ans != 0 && ans != 1)
                    {
                        throw Contracts.Except("The problem is not binary, expected answer is {0}", ans);
                    }
                    RecordPrediction(conf, dist1, dist2, pre ? 1 : 0, (int)ans);
                }
            }
            else if (ty1.RawKind() == DataKind.UInt32)
            {
                var lgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ilabel, cursor));
                var pgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ipred, cursor));
                uint ans = 0;
                bool pre = default(bool);
                while (cursor.MoveNext())
                {
                    lgetter(ref ans);
                    pgetter(ref pre);

                    // Key labels are shifted by one before being compared to 0 and 1.
                    if (ty1.IsKey())
                    {
                        --ans;
                    }
                    if (ans != 0 && ans != 1)
                    {
                        throw Contracts.Except("The problem is not binary, expected answer is {0}", ans);
                    }
                    RecordPrediction(conf, dist1, dist2, pre ? 1 : 0, (int)ans);
                }
            }
            else if (ty1.RawKind() == DataKind.Boolean)
            {
                var lgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ilabel, cursor));
                var pgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ipred, cursor));
                bool ans = default(bool);
                bool pre = default(bool);
                while (cursor.MoveNext())
                {
                    lgetter(ref ans);
                    pgetter(ref pre);
                    RecordPrediction(conf, dist1, dist2, pre ? 1 : 0, ans ? 1 : 0);
                }
            }
            else
            {
                throw new NotImplementedException(string.Format("Not implemented for type {0}", ty1));
            }
            #endregion
        }
        else if (kind == PredictionKind.Regression)
        {
            #region regression
            if (ty1.RawKind() != DataKind.Single)
            {
                throw new Exception(string.Format("Label='{0}' Predicted={1}'\nSchema: {2}", ty1, ty2, SchemaHelper.ToString(cursor.Schema)));
            }
            if (ty2.RawKind() != DataKind.Single)
            {
                throw new Exception(string.Format("Label='{0}' Predicted={1}'\nSchema: {2}", ty1, ty2, SchemaHelper.ToString(cursor.Schema)));
            }

            var lgetter = cursor.GetGetter<float>(SchemaHelper._dc(ilabel, cursor));
            var pgetter = cursor.GetGetter<float>(SchemaHelper._dc(ipred, cursor));
            float ans = 0;
            float pre = 0f;

            // The squared error is only used to detect NaN or infinite predictions.
            float error = 0f;
            while (cursor.MoveNext())
            {
                lgetter(ref ans);
                pgetter(ref pre);
                error += (ans - pre) * (ans - pre);

                // Labels and predictions are truncated to integers to build the distributions.
                // The confusion counts are not filled for a regression.
                IncrementCounter(dist1, (int)ans);
                IncrementCounter(dist2, (int)pre);
            }

            if (float.IsNaN(error) || float.IsInfinity(error))
            {
                throw new Exception("Regression wen wrong. Error is infinite.");
            }
            #endregion
        }
        else
        {
            throw new NotImplementedException(string.Format("Not implemented for kind {0}", kind));
        }

        // For a regression, conf is empty: both sums are zero and only
        // the number of distinct predictions can trigger the failure.
        var nbError = conf.Where(c => c.Key.Item1 != c.Key.Item2).Select(c => c.Value).Sum();
        var nbTotal = conf.Select(c => c.Value).Sum();

        if (checkError && (nbError * 1.0 > nbTotal * ratio || dist2.Count <= 1))
        {
            var sconf = string.Join("\n", conf.OrderBy(c => c.Key)
                              .Select(c => string.Format("pred={0} exp={1} count={2}", c.Key.Item1, c.Key.Item2, c.Value)));

            // DIST1 shows the labels (dist1), DIST2 the first 20 predicted values (dist2).
            var sdist1 = string.Join("\n", dist1.OrderBy(c => c.Key)
                              .Select(c => string.Format("label={0} count={1}", c.Key, c.Value)));
            var sdist2 = string.Join("\n", dist2.OrderBy(c => c.Key).Take(20)
                              .Select(c => string.Format("pred={0} count={1}", c.Key, c.Value)));

            throw new Exception(string.Format("Too many errors {0}/{1}={7}\n###########\nConfusion:\n{2}\n########\nDIST1\n{3}\n###########\nDIST2\n{4}\nOutput:\n{5}\n...\n{6}",
                            nbError, nbTotal,
                            sconf, sdist1, sdist2,
                            string.Join("\n", t1.Take(Math.Min(30, t1.Length))),
                            string.Join("\n", t1.Skip(Math.Max(0, t1.Length - 30))),
                            nbError * 1.0 / nbTotal));
        }
    }

    #endregion
}

/// <summary>
/// Adds one to the counter stored under <paramref name="key"/>, the counter starts at zero when the key is new.
/// </summary>
private static void IncrementCounter(Dictionary<int, int> counters, int key)
{
    int current;
    counters.TryGetValue(key, out current);
    counters[key] = current + 1;
}

/// <summary>
/// Registers one row: updates the (prediction, label) confusion count,
/// the label distribution and the prediction distribution.
/// </summary>
private static void RecordPrediction(Dictionary<Tuple<int, int>, long> confusion,
                                     Dictionary<int, int> labelCounts,
                                     Dictionary<int, int> predictionCounts,
                                     int predicted, int expected)
{
    var cell = new Tuple<int, int>(predicted, expected);
    long seen;
    confusion.TryGetValue(cell, out seen);
    confusion[cell] = seen + 1;

    IncrementCounter(labelCounts, expected);
    IncrementCounter(predictionCounts, predicted);
}
/// <summary>
/// Finalizes the test on a transform, calls the transform,
/// saves the data, saves the models, loads it back, saves the data again,
/// checks the output is the same.
/// </summary>
/// <param name="env">environment</param>
/// <param name="outModelFilePath">model filename</param>
/// <param name="transform">transform to test</param>
/// <param name="source">source (view before applying the transform)</param>
/// <param name="outData">first data file</param>
/// <param name="outData2">second data file</param>
/// <param name="startsWith">if true, a line of the first file only has to start with
/// the same line of the second file, otherwise both lines must be equal</param>
/// <param name="skipDoubleQuote">if true, lines containing two consecutive empty quoted
/// fields (<c>""</c>, tab, <c>""</c>) in either file are not compared</param>
/// <param name="forceDense">if true, the saver is created with <c>Text{dense=+}</c> instead of <c>Text</c></param>
public static void SerializationTestTransform(IHostEnvironment env,
                    string outModelFilePath, IDataTransform transform,
                    IDataView source, string outData, string outData2,
                    bool startsWith = false, bool skipDoubleQuote = false,
                    bool forceDense = false)
{
    // Saves model.
    var roles = env.CreateExamples(transform, null);
    using (var ch = env.Start("SaveModel"))
    using (var fs = File.Create(outModelFilePath))
    {
        TrainUtils.SaveModel(env, ch, fs, null, roles);
    }
    if (!File.Exists(outModelFilePath))
    {
        throw new FileNotFoundException(outModelFilePath);
    }

    // We load it again.
    using (var fs = File.OpenRead(outModelFilePath))
    {
        var tr2 = env.LoadTransforms(fs, source);
        if (tr2 == null)
        {
            throw new Exception(string.Format("Unable to load '{0}'", outModelFilePath));
        }
        if (transform.GetType() != tr2.GetType())
        {
            throw new Exception(string.Format("Type mismatch {0} != {1}", transform.GetType(), tr2.GetType()));
        }
    }

    // Checks the outputs.
    string saverSettings = forceDense ? "Text{dense=+}" : "Text";
    var saver = env.CreateSaver(saverSettings);
    var columns = Enumerable.Range(0, transform.Schema.Count).ToArray();
    using (var fs2 = File.Create(outData))
    {
        saver.SaveData(fs2, transform, columns);
    }
    // The file which was just written is the one to check.
    if (!File.Exists(outData))
    {
        throw new FileNotFoundException(outData);
    }

    // Check we have the same output.
    using (var fs = File.OpenRead(outModelFilePath))
    {
        var tr = env.LoadTransforms(fs, source);
        saver = env.CreateSaver(saverSettings);
        using (var fs2 = File.Create(outData2))
        {
            saver.SaveData(fs2, tr, columns);
        }
    }

    var t1 = File.ReadAllLines(outData);
    var t2 = File.ReadAllLines(outData2);
    if (t1.Length != t2.Length)
    {
        throw new Exception(string.Format("Not the same number of lines: {0} != {1}", t1.Length, t2.Length));
    }

    const string emptyPair = "\"\"\t\"\"";
    // Trailing text which is replaced by five empty quoted fields before a second comparison
    // (presumably the sparse encoding of these five fields, to be confirmed against the text saver format).
    const string sparseTail = "\t5\t0:\"\"";
    const string denseTail = "\t\"\"\t\"\"\t\"\"\t\"\"\t\"\"";

    for (int i = 0; i < t1.Length; ++i)
    {
        if (skipDoubleQuote && (t1[i].Contains(emptyPair) || t2[i].Contains(emptyPair)))
        {
            continue;
        }
        if (SavedLinesMatch(t1[i], t2[i], startsWith))
        {
            continue;
        }

        if (!t1[i].EndsWith(sparseTail, StringComparison.Ordinal))
        {
            // The test might fail because one side is dense and the other is sparse.
            throw new Exception(string.Format("3-Mismatch on line {0}/{3}:\n{1}\n{2}", i, t1[i], t2[i], t1.Length));
        }

        // Second chance: compares again after rewriting the end of the first line.
        var rewritten = t1[i].Substring(0, t1[i].Length - sparseTail.Length) + denseTail;
        if (!SavedLinesMatch(rewritten, t2[i], startsWith))
        {
            throw new Exception(string.Format("2-Mismatch on line {0}/{3}:\n{1}\n{2}", i, t1[i], t2[i], t1.Length));
        }
    }
}

/// <summary>
/// Tells if two saved lines match: <paramref name="first"/> must start with
/// <paramref name="second"/> when <paramref name="prefixOnly"/> is true, both must be equal otherwise.
/// The comparison is ordinal as the lines are machine-generated text.
/// </summary>
private static bool SavedLinesMatch(string first, string second, bool prefixOnly)
{
    return prefixOnly
        ? first.StartsWith(second, StringComparison.Ordinal)
        : first == second;
}