/// <summary>
/// Builds a row backed by the supplied annotations, keeping a reference to them in <c>_annotations</c>.
/// </summary>
/// <param name="annotations">Annotations to keep; validated with <c>Contracts.AssertValue</c>.</param>
public AnnotationRow(DataViewSchema.Annotations annotations)
{
    Contracts.AssertValue(annotations);

    _annotations = annotations;
}
/// <summary>
/// Describes the visible (non-hidden) columns of <paramref name="view"/> to the native environment via
/// <c>penv-&gt;dataSink</c>, then, if the sink returns setters, pushes key values and every row of data
/// through those native callbacks.
/// </summary>
/// <param name="ch">Channel used for assertions and for raising "not supported" errors.</param>
/// <param name="penv">Native environment block. If its <c>dataSink</c> is null nothing is sent.</param>
/// <param name="view">The data to send. Hidden columns are skipped.</param>
/// <param name="infos">
/// Optional per-column metadata keyed by column name. A column whose entry has <c>Expand</c> set is sent
/// as one native column per slot. Slot names come from <c>info.SlotNames</c>, else the column's
/// slot-name metadata, else <c>name.index</c>.
/// </param>
/// <remarks>
/// Throws (via <c>ch.ExceptNotSupp</c>) for variable-length vector columns. Throws (via
/// <c>Contracts.Except</c>) for item types that are neither key types nor one of the listed standard scalars.
/// </remarks>
private static unsafe void SendViewToNative(IChannel ch, EnvironmentBlock* penv, IDataView view, Dictionary<string, ColumnMetadataInfo> infos = null)
{
    Contracts.AssertValue(ch);
    Contracts.Assert(penv != null);
    Contracts.AssertValue(view);
    Contracts.AssertValueOrNull(infos);
    if (penv->dataSink == null)
    {
        // Environment doesn't want any data!
        return;
    }

    var dataSink = MarshalDelegate<DataSink>(penv->dataSink);

    var schema = view.Schema;
    // colIndices: one entry per source column that is sent.
    // kindList / keyCardList / nameIndices: one entry per native column (an expanded vector
    // column contributes one entry per slot) - see the count assertions after the loop.
    var colIndices = new List<int>();
    var kindList = new List<DataKind>();
    var keyCardList = new List<int>();
    // Concatenated name bytes; nameIndices[k] is the byte offset of the k-th name inside it
    // (used as pointer offsets in the fixed block below). Filled by AddUniqueName.
    var nameUtf8Bytes = new List<Byte>();
    var nameIndices = new List<int>();
    // NOTE(review): expandCols is populated but never read in this method.
    var expandCols = new HashSet<int>();
    var allNames = new HashSet<string>();
    for (int col = 0; col < schema.Count; col++)
    {
        if (schema[col].IsHidden)
        {
            continue;
        }

        var fullType = schema[col].Type;
        var itemType = fullType.ItemType;
        var name = schema[col].Name;

        DataKind kind = itemType.RawKind;
        int keyCard;

        if (fullType.ValueCount == 0)
        {
            throw ch.ExceptNotSupp("Column has variable length vector: " +
                name + ". Not supported in python. Drop column before sending to Python");
        }

        if (itemType.IsKey)
        {
            // Key types are returned as their signed counterparts in Python, so that -1 can be the missing value.
            // For U1 and U2 kinds, we convert to a larger type to prevent overflow. For U4 and U8 kinds, we convert
            // to I4 if the key count is known (since KeyCount is an I4), and to I8 otherwise.
            switch (kind)
            {
                case DataKind.U1:
                    kind = DataKind.I2;
                    break;
                case DataKind.U2:
                    kind = DataKind.I4;
                    break;
                case DataKind.U4:
                    // We convert known-cardinality U4 key types to I4.
                    kind = itemType.KeyCount > 0 ? DataKind.I4 : DataKind.I8;
                    break;
                case DataKind.U8:
                    // We convert known-cardinality U8 key types to I4.
                    kind = itemType.KeyCount > 0 ? DataKind.I4 : DataKind.I8;
                    break;
            }

            keyCard = itemType.KeyCount;
            // The cardinality is reported only when HasKeyValues(keyCard) holds (presumably: the column
            // carries key-value metadata of that length); otherwise -1 is sent.
            if (!schema[col].HasKeyValues(keyCard))
            {
                keyCard = -1;
            }
        }
        else if (itemType.IsStandardScalar())
        {
            switch (itemType.RawKind)
            {
                default:
                    throw Contracts.Except("Data type {0} not handled", itemType.RawKind);

                case DataKind.I1:
                case DataKind.I2:
                case DataKind.I4:
                case DataKind.I8:
                case DataKind.U1:
                case DataKind.U2:
                case DataKind.U4:
                case DataKind.U8:
                case DataKind.R4:
                case DataKind.R8:
                case DataKind.BL:
                case DataKind.TX:
                    break;
            }
            keyCard = -1;
        }
        else
        {
            throw Contracts.Except("Data type {0} not handled", itemType.RawKind);
        }

        // nSlots is the number of native columns this source column maps to.
        int nSlots;
        ColumnMetadataInfo info;
        if (infos != null && infos.TryGetValue(name, out info) && info.Expand)
        {
            expandCols.Add(col);
            Contracts.Assert(fullType.IsKnownSizeVector);
            nSlots = fullType.VectorSize;
            if (info.SlotNames != null)
            {
                Contracts.Assert(info.SlotNames.Length == nSlots);
                for (int i = 0; i < nSlots; i++)
                {
                    AddUniqueName(info.SlotNames[i], allNames, nameIndices, nameUtf8Bytes);
                }
            }
            else if (schema[col].HasSlotNames(nSlots))
            {
                var romNames = default(VBuffer<ReadOnlyMemory<char>>);
                schema[col].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref romNames);
                foreach (var kvp in romNames.Items(true))
                {
                    // REVIEW: Add the proper number of zeros to the slot index to make them sort in the right order.
                    // An empty slot name falls back to the slot index.
                    var slotName = name + "." +
                        (!kvp.Value.IsEmpty ? kvp.Value.ToString() : kvp.Key.ToString(CultureInfo.InvariantCulture));
                    AddUniqueName(slotName, allNames, nameIndices, nameUtf8Bytes);
                }
            }
            else
            {
                for (int i = 0; i < nSlots; i++)
                {
                    AddUniqueName(name + "." + i, allNames, nameIndices, nameUtf8Bytes);
                }
            }
        }
        else
        {
            nSlots = 1;
            AddUniqueName(name, allNames, nameIndices, nameUtf8Bytes);
        }

        colIndices.Add(col);
        for (int i = 0; i < nSlots; i++)
        {
            kindList.Add(kind);
            keyCardList.Add(keyCard);
        }
    }

    ch.Assert(allNames.Count == kindList.Count);
    ch.Assert(allNames.Count == keyCardList.Count);
    ch.Assert(allNames.Count == nameIndices.Count);

    var kinds = kindList.ToArray();
    var keyCards = keyCardList.ToArray();
    var nameBytes = nameUtf8Bytes.ToArray();
    var names = new byte*[allNames.Count];

    // Pin the arrays so raw pointers to them can be handed to native code.
    fixed (DataKind* prgkind = kinds)
    fixed (byte* prgbNames = nameBytes)
    fixed (byte** prgname = names)
    fixed (int* prgkeyCard = keyCards)
    {
        for (int iid = 0; iid < names.Length; iid++)
        {
            names[iid] = prgbNames + nameIndices[iid];
        }

        DataViewBlock block;
        block.ccol = allNames.Count;
        block.crow = view.GetRowCount() ?? 0;
        block.names = (sbyte**)prgname;
        block.kinds = prgkind;
        block.keyCards = prgkeyCard;

        // The native sink receives the column description and hands back per-column setters plus a key-value setter.
        dataSink(penv, &block, out var setters, out var keyValueSetter);

        if (setters == null)
        {
            // REVIEW: What should we do?
            return;
        }
        ch.Assert(keyValueSetter != null);
        var kvSet = MarshalDelegate<KeyValueSetter>(keyValueSetter);
        using (var cursor = view.GetRowCursor(colIndices.Contains))
        {
            var fillers = new BufferFillerBase[colIndices.Count];
            // pyColumn: index of the first native column for source column i.
            var pyColumn = 0;
            // keyIndex: running index passed to kvSet; it advances once per slot of each key column that has key values.
            // NOTE(review): keyIndex is independent of pyColumn - confirm the native side expects this numbering.
            var keyIndex = 0;
            for (int i = 0; i < colIndices.Count; i++)
            {
                var type = schema[colIndices[i]].Type;
                if (type.ItemType.IsKey && schema[colIndices[i]].HasKeyValues(type.ItemType.KeyCount))
                {
                    ch.Assert(schema[colIndices[i]].HasKeyValues(type.ItemType.KeyCount));
                    var keyValues = default(VBuffer<ReadOnlyMemory<char>>);
                    schema[colIndices[i]].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues);
                    for (int slot = 0; slot < type.ValueCount; slot++)
                    {
                        foreach (var kvp in keyValues.Items())
                        {
                            if (kvp.Value.IsEmpty)
                            {
                                kvSet(penv, keyIndex, kvp.Key, null, 0);
                            }
                            else
                            {
                                // Key values are sent UTF-8 encoded with an explicit byte length.
                                byte[] bt = Encoding.UTF8.GetBytes(kvp.Value.ToString());
                                fixed (byte* pt = bt)
                                    kvSet(penv, keyIndex, kvp.Key, (sbyte*)pt, bt.Length);
                            }
                        }
                        keyIndex++;
                    }
                }
                fillers[i] = BufferFillerBase.Create(penv, cursor, pyColumn, colIndices[i], kinds[pyColumn], type, setters[pyColumn]);
                // NOTE(review): this advances by VectorSize for every vector column, but kindList only received
                // VectorSize entries when the column was expanded (nSlots == 1 otherwise). This assumes callers
                // expand every vector column - confirm.
                pyColumn += type.IsVector ? type.VectorSize : 1;
            }
            for (int crow = 0; ; crow++)
            {
                // Advance to the next row.
                if (!cursor.MoveNext())
                {
                    break;
                }

                // Fill values for the current row.
                for (int i = 0; i < fillers.Length; i++)
                {
                    fillers[i].Set();
                }
            }
        }
    }
}
/// <summary>
/// Validates that the label column of <paramref name="data"/> is suitable for binary training
/// by delegating to <c>CheckBinaryLabel</c>.
/// </summary>
/// <param name="data">Role-mapped training data; must be non-null.</param>
protected override void CheckLabel(RoleMappedData data)
{
    Contracts.AssertValue(data);

    // The actual validation lives in the RoleMappedData extension.
    data.CheckBinaryLabel();
}
/// <summary>
/// Captures a getter for column <paramref name="col"/> of <paramref name="row"/> and exposes this
/// aggregator's own <c>Getter</c> method as a vector getter.
/// </summary>
/// <param name="row">Row supplying the source values.</param>
/// <param name="col">Index of the source column.</param>
public ListAggregator(Row row, int col)
{
    Contracts.AssertValue(row);

    _srcGetter = row.GetGetter<TValue>(col);
    _getter = new ValueGetter<VBuffer<TValue>>(Getter);
}
/// <summary>
/// Assumes input is sorted and finds value using BinarySearch.
/// If value is not found, returns the logical index of 'value' in the sorted list, i.e., the index of the first element greater than value.
/// In case of duplicates it returns the index of the first one.
/// It guarantees that items before the returned index are &lt; value, while those at and after the returned index are &gt;= value.
/// </summary>
/// <param name="input">The sorted array to search.</param>
/// <param name="value">The value to locate.</param>
/// <returns>An index in the range [0, <c>input.Length</c>].</returns>
public static int FindIndexSorted(this Double[] input, Double value)
{
    Contracts.AssertValue(input);
    return FindIndexSorted(input, 0, input.Length, value);
}
/// <summary>
/// Emits <c>node.Value</c> to the writer <c>_wrt</c>.
/// </summary>
/// <param name="node">The name node to print.</param>
public override void Visit(NameNode node)
{
    Contracts.AssertValue(node);

    var text = node.Value;
    _wrt.Write(text);
}
/// <summary>
/// Appends this object's textual form to <paramref name="sb"/> by delegating to <c>TryUnparseCore</c>.
/// </summary>
/// <param name="sb">Destination builder; must be non-null.</param>
/// <returns>The result of <c>TryUnparseCore</c>.</returns>
public bool TryUnparse(StringBuilder sb)
{
    Contracts.AssertValue(sb);

    var succeeded = TryUnparseCore(sb);
    return succeeded;
}
/// <summary>
/// Records <paramref name="parent"/> in <c>_parent</c> and touches the schema entry's view.
/// </summary>
/// <param name="parent">The owning <c>TransposeLoader</c>; must be non-null.</param>
public SchemaImpl(TransposeLoader parent)
{
    Contracts.AssertValue(parent);
    _parent = parent;

    // The result is not needed here. The call is evaluated only for its possible side effects:
    // GetView() may load the view on demand and may throw on a bad file. Keeping it preserves when
    // that happens. NOTE(review): confirm whether this eager evaluation is actually required.
    _ = parent._schemaEntry.GetView().Schema;
}
/// <summary>
/// Deserializes the column bindings (<c>Infos</c>) from <paramref name="ctx"/> and then computes
/// <c>OutputSchema</c>. The expected binary layout is listed in the comment at the top of the body.
/// </summary>
/// <param name="ctx">Load context whose reader is positioned at the serialized bindings.</param>
/// <param name="parent">The owning loader. It is not referenced in this constructor's body.</param>
/// <remarks>Malformed input is rejected via <c>Contracts.CheckDecode</c>.</remarks>
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   int: number of segments
    //   foreach segment:
    //     string id: name
    //     int: min
    //     int: lim
    //     byte: force vector (verWrittenCur: verIsVectorSupported)
    int cinfo = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(cinfo > 0);
    Infos = new ColInfo[cinfo];

    for (int iinfo = 0; iinfo < cinfo; iinfo++)
    {
        string name = ctx.LoadNonEmptyString();

        PrimitiveDataViewType itemType;
        var kind = (InternalDataKind)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));
        bool isKey = ctx.Reader.ReadBoolByte();
        if (isKey)
        {
            ulong count;
            Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));

            count = ctx.Reader.ReadUInt64();
            Contracts.CheckDecode(0 < count);

            itemType = new KeyDataViewType(kind.ToType(), count);
        }
        else
        {
            itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        // A segment count of zero is stored as a null segment array.
        int cseg = ctx.Reader.ReadInt32();
        Segment[] segs;

        if (cseg == 0)
        {
            segs = null;
        }
        else
        {
            Contracts.CheckDecode(cseg > 0);
            segs = new Segment[cseg];
            for (int iseg = 0; iseg < cseg; iseg++)
            {
                // A null name means an index-range segment; a non-null name means a segment bound by column name.
                string columnName = ctx.LoadStringOrNull();
                int min = ctx.Reader.ReadInt32();
                int lim = ctx.Reader.ReadInt32();
                // NOTE(review): this range check also runs for named segments, whose min/lim are discarded
                // below. Confirm that the writer stores a valid range for those too.
                Contracts.CheckDecode(0 <= min && min < lim);
                bool forceVector = ctx.Reader.ReadBoolByte();
                segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector);
            }
        }

        // Note that this will throw if the segments are ill-structured, including the case
        // of multiple variable segments (since those segments will overlap and overlapping
        // segments are illegal).
        Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
    }

    OutputSchema = ComputeOutputSchema();
}
/// <summary>
/// Reports whether <paramref name="a"/> was allocated with the same alignment as this type's <c>CbAlign</c>.
/// </summary>
/// <param name="a">A non-null, non-empty aligned array.</param>
/// <returns><c>true</c> when <c>a.CbAlign</c> equals <c>CbAlign</c>.</returns>
private static bool Compat(AlignedArray a)
{
    Contracts.AssertValue(a);
    Contracts.Assert(a.Size > 0);

    bool sameAlignment = a.CbAlign == CbAlign;
    return sameAlignment;
}
/// <summary>
/// Creates a table-of-contents entry that has no contents and no offset; only the owning
/// loader is recorded.
/// </summary>
/// <param name="parent">The loader that owns this entry; must be non-null.</param>
private SubIdvEntry(TransposeLoader parent)
{
    Contracts.AssertValue(parent);

    _parent = parent;
}
/// <summary>
/// Computes a MurmurHash over every character in <paramref name="sb"/>. The result MUST be identical
/// to <c>HashString(sb.ToString())</c>, so both overloads use the same seed.
/// </summary>
/// <param name="sb">The characters to hash; must be non-null.</param>
/// <returns>The 32-bit hash of the builder's contents.</returns>
public static uint HashString(StringBuilder sb)
{
    Contracts.AssertValue(sb);

    // Shared seed for all HashString overloads.
    const int seed = (5381 << 16) + 5381;
    return MurmurHash(seed, sb, 0, sb.Length);
}
/// <summary>
/// Computes a MurmurHash over every character in <paramref name="str"/>. The result MUST match
/// the other overloads when they are given the same characters.
/// </summary>
/// <param name="str">The string to hash; must be non-null.</param>
/// <returns>The 32-bit hash of the string.</returns>
public static uint HashString(string str)
{
    Contracts.AssertValue(str);

    // Shared seed for all HashString overloads.
    const int seed = (5381 << 16) + 5381;
    return MurmurHash(seed, str, 0, str.Length);
}
/// <summary>
/// Wraps <paramref name="pred"/>, storing it in <c>_pred</c>.
/// </summary>
/// <param name="pred">The predictor whose weights are exposed; must be non-null.</param>
public WeightsCollection(LinearPredictor pred)
{
    Contracts.AssertValue(pred);

    _pred = pred;
}
/// <summary>
/// Prints a string literal: its value via <c>Show</c>, followed by its type via <c>ShowType</c>.
/// </summary>
/// <param name="node">The string literal node.</param>
public override void Visit(StrLitNode node)
{
    Contracts.AssertValue(node);

    var literal = node.Value;
    Show(literal);
    ShowType(node);
}
/// <summary>
/// Binds this value to <paramref name="cursor"/>, which is exposed through <c>Cursor</c>.
/// </summary>
/// <param name="cursor">The cursor to bind to; must be non-null.</param>
protected Value(RowCursor cursor)
{
    Contracts.AssertValue(cursor);

    Cursor = cursor;
}
/// <summary>
/// Prints a numeric literal: <c>node.Value.ToString()</c> to <c>_wrt</c>, then its type via <c>ShowType</c>.
/// </summary>
/// <param name="node">The numeric literal node.</param>
public override void Visit(NumLitNode node)
{
    Contracts.AssertValue(node);

    string text = node.Value.ToString();
    _wrt.Write(text);
    ShowType(node);
}
/// <summary>
/// Initializes the transformer with the host it reports through, stored in <c>Host</c>.
/// </summary>
/// <param name="host">The host environment; must be non-null.</param>
protected RowToRowTransformerBase(IHost host)
{
    Contracts.AssertValue(host);

    Host = host;
}
/// <summary>
/// Prints a parameter reference: <c>node.Name</c> to <c>_wrt</c>, then its type via <c>ShowType</c>.
/// </summary>
/// <param name="node">The parameter node.</param>
public override void Visit(ParamNode node)
{
    Contracts.AssertValue(node);

    var paramName = node.Name;
    _wrt.Write(paramName);
    ShowType(node);
}
/// <summary>
/// Overwrites this vector's contents with those of <paramref name="src"/>.
/// Both vectors are required to have equal sizes.
/// </summary>
/// <param name="src">Vector to copy from; must be non-null and the same size as this one.</param>
public void CopyFrom(CpuAlignedVector src)
{
    Contracts.AssertValue(src);
    Contracts.Assert(src._size == _size);

    // The element storage performs the copy itself.
    _items.CopyFrom(src._items);
}
/// <summary>
/// Creates the output column. A new <c>Reconciler</c> built from <paramref name="relativeTo"/> and
/// the input <paramref name="path"/> are passed to the base constructor, and <paramref name="path"/>
/// is also kept in <c>_input</c>.
/// </summary>
/// <param name="path">Input column holding paths; must be non-null.</param>
/// <param name="relativeTo">Value handed to the <c>Reconciler</c>.</param>
public OutPipelineColumn(Scalar<string> path, string relativeTo)
    : base(new Reconciler(relativeTo), path)
{
    Contracts.AssertValue(path);

    _input = path;
}
/// <summary>
/// Serializes the fields of <paramref name="instance"/> (which must be exactly of type <c>_type</c>)
/// into a <see cref="JObject"/>. Fields with variable bindings are emitted as <c>$name</c> references.
/// Other fields are emitted only when they are required or differ from a default-constructed instance.
/// </summary>
/// <param name="instance">Object to serialize; must be non-null and of type <c>_type</c>.</param>
/// <param name="inputBindingMap">Field name to parameter bindings; a field present here is serialized as variable references.</param>
/// <param name="inputMap">Parameter binding to the variable it refers to.</param>
/// <returns>The JSON object holding the serialized fields.</returns>
/// <remarks>Throws via <c>Contracts.ExceptNotImpl</c> for dictionary-keyed variable bindings.</remarks>
public JObject GetJsonObject(object instance, Dictionary<string, List<ParameterBinding>> inputBindingMap, Dictionary<ParameterBinding, VariableBinding> inputMap)
{
    Contracts.CheckValue(instance, nameof(instance));
    Contracts.Check(instance.GetType() == _type);

    var result = new JObject();
    // A default-constructed instance supplies the baseline; fields equal to it are omitted unless required.
    var defaults = Activator.CreateInstance(_type);
    for (int i = 0; i < _fields.Length; i++)
    {
        var field = _fields[i];
        var attr = _attrs[i];
        var instanceVal = field.GetValue(instance);
        var defaultsVal = field.GetValue(defaults);

        if (inputBindingMap.TryGetValue(field.Name, out List<ParameterBinding> bindings))
        {
            // Handle variables.
            Contracts.Assert(bindings.Count > 0);
            VariableBinding varBinding;
            var paramBinding = bindings[0];
            if (paramBinding is SimpleParameterBinding)
            {
                Contracts.Assert(bindings.Count == 1);
                bool success = inputMap.TryGetValue(paramBinding, out varBinding);
                Contracts.Assert(success);
                Contracts.AssertValue(varBinding);

                result.Add(field.Name, new JValue(varBinding.ToJson()));
            }
            else if (paramBinding is ArrayIndexParameterBinding)
            {
                // Array parameter bindings.
                var array = new JArray();
                foreach (var parameterBinding in bindings)
                {
                    Contracts.Assert(parameterBinding is ArrayIndexParameterBinding);
                    bool success = inputMap.TryGetValue(parameterBinding, out varBinding);
                    Contracts.Assert(success);
                    Contracts.AssertValue(varBinding);
                    array.Add(new JValue(varBinding.ToJson()));
                }

                result.Add(field.Name, array);
            }
            else
            {
                // Dictionary parameter bindings. Not supported yet.
                Contracts.Assert(paramBinding is DictionaryKeyParameterBinding);
                throw Contracts.ExceptNotImpl("Dictionary of variables not yet implemented.");
            }
        }
        else if (instanceVal == null && defaultsVal != null)
        {
            // Handle null values.
            // An explicit JSON null is written only when the default is non-null; a field that is null in
            // both the instance and the defaults falls through every branch and is omitted.
            result.Add(field.Name, new JValue(instanceVal));
        }
        else if (instanceVal != null && (attr.Input.IsRequired || !instanceVal.Equals(defaultsVal)))
        {
            // A required field will be serialized regardless of whether or not its value is identical to the default.
            var type = instanceVal.GetType();
            if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>))
            {
                // ExtractOptional presumably unwraps the Optional<T> in place (instanceVal/type) and reports
                // whether the value was set explicitly; non-explicit optionals are skipped.
                var isExplicit = ExtractOptional(ref instanceVal, ref type);
                if (!isExplicit)
                {
                    continue;
                }
            }

            if (type == typeof(JArray))
            {
                result.Add(field.Name, (JArray)instanceVal);
            }
            else if (type.IsGenericType &&
                ((type.GetGenericTypeDefinition() == typeof(Var<>)) ||
                type.GetGenericTypeDefinition() == typeof(ArrayVar<>) ||
                type.GetGenericTypeDefinition() == typeof(DictionaryVar<>)))
            {
                // Variables are written as "$<VarName>".
                result.Add(field.Name, new JValue($"${((IVarSerializationHelper)instanceVal).VarName}"));
            }
            else if (type == typeof(bool) ||
                type == typeof(string) ||
                type == typeof(char) ||
                type == typeof(double) ||
                type == typeof(float) ||
                type == typeof(int) ||
                type == typeof(long) ||
                type == typeof(uint) ||
                type == typeof(ulong))
            {
                // Handle simple types.
                result.Add(field.Name, new JValue(instanceVal));
            }
            else if (type.IsEnum)
            {
                // Handle enums.
                result.Add(field.Name, new JValue(instanceVal.ToString()));
            }
            else if (type.IsArray)
            {
                // Handle arrays.
                var array = (Array)instanceVal;
                var jarray = new JArray();
                var elementType = type.GetElementType();
                if (elementType == typeof(bool) ||
                    elementType == typeof(string) ||
                    elementType == typeof(char) ||
                    elementType == typeof(double) ||
                    elementType == typeof(float) ||
                    elementType == typeof(int) ||
                    elementType == typeof(long) ||
                    elementType == typeof(uint) ||
                    elementType == typeof(ulong))
                {
                    foreach (object item in array)
                    {
                        jarray.Add(new JValue(item));
                    }
                }
                else
                {
                    // Non-primitive elements are serialized recursively with a builder for the element type.
                    var builder = new InputBuilder(_ectx, elementType, _catalog);
                    foreach (object item in array)
                    {
                        jarray.Add(builder.GetJsonObject(item, inputBindingMap, inputMap));
                    }
                }
                result.Add(field.Name, jarray);
            }
            else if (type.IsGenericType &&
                type.GetGenericTypeDefinition() == typeof(Dictionary<,>) &&
                type.GetGenericArguments()[0] == typeof(string))
            {
                // Handle dictionaries.
                // REVIEW: Needs to be implemented when we will have entry point arguments that contain dictionaries.
                // NOTE(review): as written, such a field is silently omitted from the output.
            }
            else if (typeof(IComponentFactory).IsAssignableFrom(type))
            {
                // Handle component factories.
                bool success = _catalog.TryFindComponent(type, out ComponentCatalog.ComponentInfo instanceInfo);
                Contracts.Assert(success);
                var builder = new InputBuilder(_ectx, type, _catalog);
                var instSettings = builder.GetJsonObject(instanceVal, inputBindingMap, inputMap);

                ComponentCatalog.ComponentInfo defaultInfo = null;
                JObject defSettings = new JObject();
                if (defaultsVal != null)
                {
                    var deftype = defaultsVal.GetType();
                    if (deftype.IsGenericType && deftype.GetGenericTypeDefinition() == typeof(Optional<>))
                    {
                        ExtractOptional(ref defaultsVal, ref deftype);
                    }
                    success = _catalog.TryFindComponent(deftype, out defaultInfo);
                    Contracts.Assert(success);
                    builder = new InputBuilder(_ectx, deftype, _catalog);
                    defSettings = builder.GetJsonObject(defaultsVal, inputBindingMap, inputMap);
                }

                // Emit the component only if its name or serialized settings differ from the default's;
                // the settings object is included only when the settings themselves differ.
                if (instanceInfo.Name != defaultInfo?.Name || instSettings.ToString() != defSettings.ToString())
                {
                    var jcomponent = new JObject
                    {
                        { FieldNames.Name, new JValue(instanceInfo.Name) }
                    };
                    if (instSettings.ToString() != defSettings.ToString())
                    {
                        jcomponent.Add(FieldNames.Settings, instSettings);
                    }
                    result.Add(field.Name, jcomponent);
                }
            }
            else
            {
                // REVIEW: pass in the bindings once we support variables in inner fields.

                // Handle structs.
                var builder = new InputBuilder(_ectx, type, _catalog);
                result.Add(field.Name, builder.GetJsonObject(instanceVal,
                    new Dictionary<string, List<ParameterBinding>>(),
                    new Dictionary<ParameterBinding, VariableBinding>()));
            }
        }
    }

    return(result);
}
/// <summary>
/// Assumes input is sorted and finds value using BinarySearch.
/// If value is not found, returns the logical index of 'value' in the sorted list, i.e., the index of the first element greater than value.
/// In case of duplicates it returns the index of the first one.
/// It guarantees that items before the returned index are &lt; value, while those at and after the returned index are &gt;= value.
/// </summary>
/// <param name="input">The sorted list to search.</param>
/// <param name="value">The value to locate.</param>
/// <returns>An index in the range [0, <c>input.Count</c>].</returns>
public static int FindIndexSorted(this IList<float> input, float value)
{
    Contracts.AssertValue(input);
    return FindIndexSorted(input, 0, input.Count, value);
}
/// <summary>
/// Converts a JSON token into a CLR value of <paramref name="type"/>, dispatching on the
/// <c>TlcModule.DataKind</c> of the type. <c>Optional&lt;T&gt;</c> and <c>Nullable&lt;T&gt;</c> are unwrapped to
/// <c>T</c>. <c>Var&lt;T&gt;</c> becomes a variable reference. Arrays, dictionaries, components and nested
/// input objects are built recursively.
/// </summary>
/// <param name="ectx">Exception context used for checks and errors.</param>
/// <param name="type">Target CLR type.</param>
/// <param name="attributes">Attributes forwarded to the array/dictionary builders.</param>
/// <param name="value">The token to convert; may be null.</param>
/// <param name="catalog">Component catalog used to resolve components and nested inputs.</param>
/// <returns>The converted value, or <c>null</c> when <paramref name="value"/> is null or a JSON null.</returns>
/// <remarks>
/// Only <see cref="FormatException"/> is wrapped into a contextual error (unless already marked).
/// Other exceptions raised by the conversions propagate unchanged.
/// </remarks>
private static object ParseJsonValue(IExceptionContext ectx, Type type, Attributes attributes, JToken value, ComponentCatalog catalog)
{
    Contracts.AssertValue(ectx);
    ectx.AssertValue(type);
    ectx.AssertValueOrNull(value);
    ectx.AssertValue(catalog);

    if (value == null)
    {
        return(null);
    }

    if (value is JValue val && val.Value == null)
    {
        return(null);
    }

    if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>)))
    {
        // For Optional<T>, a token that has child values is replaced by its first child value.
        if (type.GetGenericTypeDefinition() == typeof(Optional<>) && value.HasValues)
        {
            value = value.Values().FirstOrDefault();
        }
        type = type.GetGenericArguments()[0];
    }

    if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Var<>)))
    {
        // NOTE(review): value.Value<string>() runs before the IsBindingToken check below, so a non-string
        // token may fail inside Json.NET rather than with "Variable name expected.".
        string varName = value.Value<string>();
        ectx.Check(VariableBinding.IsBindingToken(value), "Variable name expected.");
        var variable = Activator.CreateInstance(type) as IVarSerializationHelper;
        var varBinding = VariableBinding.Create(ectx, varName);
        variable.VarName = varBinding.VariableName;
        return(variable);
    }

    if (type == typeof(JArray) && value is JArray)
    {
        return(value);
    }

    TlcModule.DataKind dt = TlcModule.GetDataType(type);

    try
    {
        switch (dt)
        {
            case TlcModule.DataKind.Bool:
                return(value.Value<bool>());
            case TlcModule.DataKind.String:
                return(value.Value<string>());
            case TlcModule.DataKind.Char:
                return(value.Value<char>());
            case TlcModule.DataKind.Enum:
                // Enum members are matched by their (case-sensitive) name.
                if (!Enum.IsDefined(type, value.Value<string>()))
                {
                    throw ectx.Except($"Requested value '{value.Value<string>()}' is not a member of the Enum type '{type.Name}'");
                }
                return(Enum.Parse(type, value.Value<string>()));
            case TlcModule.DataKind.Float:
                if (type == typeof(double))
                {
                    return(value.Value<double>());
                }
                else if (type == typeof(float))
                {
                    return(value.Value<float>());
                }
                else
                {
                    ectx.Assert(false);
                    throw ectx.ExceptNotSupp();
                }
            case TlcModule.DataKind.Array:
                var ja = value as JArray;
                ectx.Check(ja != null, "Expected array value");
                // MakeArray<int> is only a template; MarshalInvoke re-invokes it with the actual element type.
                Func<IExceptionContext, JArray, Attributes, ComponentCatalog, object> makeArray = MakeArray<int>;
                return(Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog));
            case TlcModule.DataKind.Int:
                if (type == typeof(long))
                {
                    return(value.Value<long>());
                }
                if (type == typeof(int))
                {
                    return(value.Value<int>());
                }
                ectx.Assert(false);
                throw ectx.ExceptNotSupp();
            case TlcModule.DataKind.UInt:
                if (type == typeof(ulong))
                {
                    return(value.Value<ulong>());
                }
                if (type == typeof(uint))
                {
                    return(value.Value<uint>());
                }
                ectx.Assert(false);
                throw ectx.ExceptNotSupp();
            case TlcModule.DataKind.Dictionary:
                ectx.Check(value is JObject, "Expected object value");
                // As with arrays, MakeDictionary<int> is re-invoked with the dictionary's value type.
                Func<IExceptionContext, JObject, Attributes, ComponentCatalog, object> makeDict = MakeDictionary<int>;
                return(Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog));
            case TlcModule.DataKind.Component:
                var jo = value as JObject;
                ectx.Check(jo != null, "Expected object value");
                // REVIEW: consider accepting strings alone.
                var jName = jo[FieldNames.Name];
                ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component.");
                ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string.");
                var name = jName.Value<string>();
                ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject,
                    "Expected '" + FieldNames.Settings + "' field to be an object");
                return(GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog));
            default:
                // Any other type is treated as a nested input object whose fields are set from the JSON properties.
                var settings = value as JObject;
                ectx.Check(settings != null, "Expected object value");
                var inputBuilder = new InputBuilder(ectx, type, catalog);

                if (inputBuilder._fields.Length == 0)
                {
                    throw ectx.Except($"Unsupported input type: {dt}");
                }

                if (settings != null)
                {
                    foreach (var pair in settings)
                    {
                        if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value))
                        {
                            throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'");
                        }
                    }
                }

                var missing = inputBuilder.GetMissingValues().ToArray();
                if (missing.Length > 0)
                {
                    throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}");
                }
                return(inputBuilder.GetInstance());
        }
    }
    catch (FormatException ex)
    {
        // Already-marked exceptions were produced by this library and are rethrown as-is.
        if (ex.IsMarked())
        {
            throw;
        }
        throw ectx.Except(ex, $"Failed to parse JSON value '{value}' as {type}");
    }
}
/// <summary>
/// Re-creates a <c>ParquetLoader</c> from a serialized model.
/// </summary>
/// <param name="host">Host used for assertions and error reporting.</param>
/// <param name="ctx">Load context positioned at the loader's serialized state.</param>
/// <param name="files">
/// Source files. When at least one file is present, the Parquet stream is opened and the schema is read
/// from it. Otherwise the schema stored in the model (format 0x00010002 and later) is used, and the
/// constructor fails if none is stored.
/// </param>
/// <exception cref="InvalidDataException">The Parquet file's metadata could not be read.</exception>
private ParquetLoader(IHost host, ModelLoadContext ctx, IMultiStreamSource files)
{
    Contracts.AssertValue(host);
    _host = host;
    _host.AssertValue(ctx);
    _host.AssertValue(files);

    // *** Binary format ***
    // int: cached chunk size
    // bool: TreatBigIntegersAsDates flag
    // Schema of the loader (0x00010002)

    _columnChunkReadSize = ctx.Reader.ReadInt32();
    bool treatBigIntegersAsDates = ctx.Reader.ReadBoolean();

    if (ctx.Header.ModelVerWritten >= 0x00010002)
    {
        // Load the schema
        byte[] buffer = null;
        if (!ctx.TryLoadBinaryStream(SchemaCtxName, r => buffer = r.ReadByteArray()))
        {
            throw _host.ExceptDecode();
        }
        var strm = new MemoryStream(buffer, writable: false);
        var loader = new BinaryLoader(_host, new BinaryLoader.Arguments(), strm);
        Schema = loader.Schema;
    }

    // Only load Parquet related data if a file is present. Otherwise, just the Schema is valid.
    if (files.Count > 0)
    {
        _parquetOptions = new ParquetOptions()
        {
            TreatByteArrayAsString = true,
            TreatBigIntegersAsDates = treatBigIntegersAsDates
        };

        _parquetStream = OpenStream(files);
        try
        {
            DataSet schemaDataSet;

            try
            {
                // We only care about the schema so ignore the rows.
                ReaderOptions readerOptions = new ReaderOptions()
                {
                    Count = 0,
                    Offset = 0
                };
                schemaDataSet = ParquetReader.Read(_parquetStream, _parquetOptions, readerOptions);
                _rowCount = schemaDataSet.TotalRowCount;
            }
            catch (Exception ex)
            {
                throw new InvalidDataException("Cannot read Parquet file", ex);
            }

            _columnsLoaded = InitColumns(schemaDataSet);
            Schema = CreateSchema(_host, _columnsLoaded);
        }
        catch
        {
            // A constructor that throws never hands the object to a caller, so nothing else could
            // dispose the stream opened above. Release it here before propagating the failure.
            _parquetStream?.Dispose();
            throw;
        }
    }
    else if (Schema == null)
    {
        throw _host.Except("Parquet loader must be created with one file");
    }
}
/// <summary>
/// Validates that the label column of <paramref name="data"/> is suitable for regression
/// by delegating to <c>CheckRegressionLabel</c>.
/// </summary>
/// <param name="data">Role-mapped training data; must be non-null.</param>
private protected override void CheckLabel(RoleMappedData data)
{
    Contracts.AssertValue(data);

    // The actual validation lives in the RoleMappedData extension.
    data.CheckRegressionLabel();
}
/// <summary>
/// Logs summary statistics for the trained model: training-example count, residual deviance, null
/// deviance and AIC. It then computes per-coefficient standard errors from the inverse of the Hessian
/// (built from Bernoulli variances per the in-code comments) and stores the result in <c>_stats</c>.
/// The standard errors are skipped when the packed Hessian would exceed <c>int.MaxValue</c> entries.
/// </summary>
/// <param name="ch">Channel used for info/warning messages and assertions.</param>
/// <param name="cursorFactory">Creates cursors over the training rows (label, weight, features).</param>
/// <param name="loss">Final training loss; asserted non-negative. The deviance starts as <c>2 * loss * WeightSum</c>.</param>
/// <param name="numParams">
/// Number of model parameters including the bias. When it is smaller than <c>CurrentWeights.Length</c>,
/// only the bias and the non-zero weights are treated as parameters.
/// </param>
/// <remarks>Throws via <c>ch.ExceptNotSupp</c> if the MKL native library cannot be loaded.</remarks>
protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, Float loss, int numParams)
{
    Contracts.AssertValue(ch);
    Contracts.AssertValue(cursorFactory);
    Contracts.Assert(NumGoodRows > 0);
    Contracts.Assert(WeightSum > 0);
    Contracts.Assert(BiasCount == 1);
    Contracts.Assert(loss >= 0);
    Contracts.Assert(numParams >= BiasCount);
    Contracts.Assert(CurrentWeights.IsDense);

    ch.Info("Model trained with {0} training examples.", NumGoodRows);

    // Compute deviance: start with loss function.
    Float deviance = (Float)(2 * loss * WeightSum);

    if (L2Weight > 0)
    {
        // Need to subtract L2 regularization loss.
        // The bias term is not regularized.
        var regLoss = VectorUtils.NormSquared(CurrentWeights.Values, 1, CurrentWeights.Length - 1) * L2Weight;
        deviance -= regLoss;
    }

    if (L1Weight > 0)
    {
        // Need to subtract L1 regularization loss.
        // The bias term is not regularized.
        Double regLoss = 0;
        VBufferUtils.ForEachDefined(ref CurrentWeights, (ind, value) => { if (ind >= BiasCount) { regLoss += Math.Abs(value); } });
        deviance -= (Float)regLoss * L1Weight * 2;
    }

    ch.Info("Residual Deviance: \t{0} (on {1} degrees of freedom)", deviance, Math.Max(NumGoodRows - numParams, 0));

    // Compute null deviance, i.e., the deviance of null hypothesis.
    // Cap the prior positive rate at 1e-15.
    Double priorPosRate = _posWeight / WeightSum;
    Contracts.Assert(0 <= priorPosRate && priorPosRate <= 1);
    Float nullDeviance = (priorPosRate <= 1e-15 || 1 - priorPosRate <= 1e-15) ?
        0f : (Float)(2 * WeightSum * MathUtils.Entropy(priorPosRate, true));
    ch.Info("Null Deviance: \t{0} (on {1} degrees of freedom)", nullDeviance, NumGoodRows - 1);

    // Compute AIC.
    ch.Info("AIC: \t{0}", 2 * numParams + deviance);

    // Show the coefficients statistics table.
    // NOTE(review): namesSpans is only length-checked below and never used afterwards in this method.
    var featureColIdx = cursorFactory.Data.Schema.Feature.Index;
    var schema = cursorFactory.Data.Data.Schema;
    var featureLength = CurrentWeights.Length - BiasCount;
    var namesSpans = VBufferUtils.CreateEmpty<DvText>(featureLength);
    if (schema.HasSlotNames(featureColIdx, featureLength))
    {
        schema.GetMetadata(MetadataUtils.Kinds.SlotNames, featureColIdx, ref namesSpans);
    }
    Host.Assert(namesSpans.Length == featureLength);

    // Inverse mapping of non-zero weight slots.
    Dictionary<int, int> weightIndicesInvMap = null;

    // Indices of bias and non-zero weight slots.
    int[] weightIndices = null;

    // Whether all weights are non-zero.
    bool denseWeight = numParams == CurrentWeights.Length;

    // Extract non-zero indices of weight.
    if (!denseWeight)
    {
        weightIndices = new int[numParams];
        weightIndicesInvMap = new Dictionary<int, int>(numParams);
        weightIndices[0] = 0;
        weightIndicesInvMap[0] = 0;
        int j = 1;
        for (int i = 1; i < CurrentWeights.Length; i++)
        {
            if (CurrentWeights.Values[i] != 0)
            {
                weightIndices[j] = i;
                weightIndicesInvMap[i] = j++;
            }
        }

        Contracts.Assert(j == numParams);
    }

    // Compute the standard error of coefficients.
    long hessianDimension = (long)numParams * (numParams + 1) / 2;
    if (hessianDimension > int.MaxValue)
    {
        ch.Warning("The number of parameters is too large. Cannot hold the variance-covariance matrix in memory. " +
            "Skipping computation of standard errors and z-statistics of coefficients. Consider choosing a larger L1 regularizer " +
            "to reduce the number of parameters.");
        _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);
        return;
    }

    // Building the variance-covariance matrix for parameters.
    // The layout of this algorithm is a packed row-major lower triangular matrix.
    // E.g., layout of indices for 4-by-4:
    // 0
    // 1 2
    // 3 4 5
    // 6 7 8 9
    var hessian = new Double[hessianDimension];

    // Initialize diagonal elements with L2 regularizers except for the first entry (index 0)
    // Since bias is not regularized.
    if (L2Weight > 0)
    {
        // i is the array index of the diagonal entry at iRow-th row and iRow-th column.
        // iRow is one-based.
        int i = 0;
        for (int iRow = 2; iRow <= numParams; iRow++)
        {
            i += iRow;
            hessian[i] = L2Weight;
        }

        Contracts.Assert(i == hessian.Length - 1);
    }

    // Initialize the remaining entries.
    var bias = CurrentWeights.Values[0];
    using (var cursor = cursorFactory.Create())
    {
        while (cursor.MoveNext())
        {
            var weight = cursor.Weight;
            var score = bias + VectorUtils.DotProductWithOffset(ref CurrentWeights, 1, ref cursor.Features);
            // Compute Bernoulli variance n_i * p_i * (1 - p_i) for the i-th training example.
            // With p = sigmoid(score), p * (1 - p) == 1 / (2 + 2 * cosh(score)).
            var variance = weight / (2 + 2 * Math.Cosh(score));

            // Increment the first entry of hessian.
            hessian[0] += variance;

            var values = cursor.Features.Values;
            if (cursor.Features.IsDense)
            {
                int ioff = 1;

                // Increment remaining entries of hessian.
                for (int i = 1; i < numParams; i++)
                {
                    ch.Assert(ioff == i * (i + 1) / 2);
                    int wi = weightIndices == null ? i - 1 : weightIndices[i] - 1;
                    Contracts.Assert(0 <= wi && wi < cursor.Features.Length);
                    var val = values[wi] * variance;
                    // Add the implicit first bias term to X'X
                    hessian[ioff++] += val;
                    // Add the remainder of X'X
                    for (int j = 0; j < i; j++)
                    {
                        int wj = weightIndices == null ? j : weightIndices[j + 1] - 1;
                        Contracts.Assert(0 <= wj && wj < cursor.Features.Length);
                        hessian[ioff++] += val * values[wj];
                    }
                }
                ch.Assert(ioff == hessian.Length);
            }
            else
            {
                var indices = cursor.Features.Indices;
                for (int ii = 0; ii < cursor.Features.Count; ++ii)
                {
                    int i = indices[ii];
                    int wi = i + 1;
                    // Features whose weight is zero are not parameters and are skipped.
                    if (weightIndicesInvMap != null && !weightIndicesInvMap.TryGetValue(i + 1, out wi))
                    {
                        continue;
                    }

                    Contracts.Assert(0 < wi && wi <= cursor.Features.Length);
                    int ioff = wi * (wi + 1) / 2;
                    var val = values[ii] * variance;
                    // Add the implicit first bias term to X'X
                    hessian[ioff] += val;
                    // Add the remainder of X'X
                    for (int jj = 0; jj <= ii; jj++)
                    {
                        int j = indices[jj];
                        int wj = j + 1;
                        if (weightIndicesInvMap != null && !weightIndicesInvMap.TryGetValue(j + 1, out wj))
                        {
                            continue;
                        }

                        Contracts.Assert(0 < wj && wj <= cursor.Features.Length);
                        hessian[ioff + wj] += val * values[jj];
                    }
                }
            }
        }
    }

    // Apply Cholesky Decomposition to find the inverse of the Hessian.
    Double[] invHessian = null;
    try
    {
        // First, find the Cholesky decomposition LL' of the Hessian.
        Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numParams, hessian);
        // Note that hessian is already modified at this point. It is no longer the original Hessian,
        // but instead represents the Cholesky decomposition L.
        // Also note that the following routine is supposed to consume the Cholesky decomposition L instead
        // of the original information matrix.
        Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numParams, hessian);
        // At this point, hessian should contain the inverse of the original Hessian matrix.
        // Swap hessian with invHessian to avoid confusion in the following context.
        Utils.Swap(ref hessian, ref invHessian);
        Contracts.Assert(hessian == null);
    }
    catch (DllNotFoundException)
    {
        throw ch.ExceptNotSupp("The MKL library (Microsoft.ML.MklImports.dll) or one of its dependencies is missing.");
    }

    Float[] stdErrorValues = new Float[numParams];
    stdErrorValues[0] = (Float)Math.Sqrt(invHessian[0]);

    for (int i = 1; i < numParams; i++)
    {
        // Initialize with inverse Hessian.
        // i * (i + 1) / 2 + i is the packed index of diagonal entry (i, i).
        stdErrorValues[i] = (Single)invHessian[i * (i + 1) / 2 + i];
    }

    if (L2Weight > 0)
    {
        // Iterate through all entries of inverse Hessian to make adjustment to variance.
        // A discussion on ridge regularized LR coefficient covariance matrix can be found here:
        // http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3228544/
        // http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf
        int ioffset = 1;
        for (int iRow = 1; iRow < numParams; iRow++)
        {
            for (int iCol = 0; iCol <= iRow; iCol++)
            {
                var entry = (Single)invHessian[ioffset];
                var adjustment = -L2Weight * entry * entry;
                stdErrorValues[iRow] -= adjustment;
                if (0 < iCol && iCol < iRow)
                {
                    stdErrorValues[iCol] -= adjustment;
                }
                ioffset++;
            }
        }

        Contracts.Assert(ioffset == invHessian.Length);
    }

    // Convert variances to standard errors (entry 0 was already square-rooted above).
    for (int i = 1; i < numParams; i++)
    {
        stdErrorValues[i] = (Float)Math.Sqrt(stdErrorValues[i]);
    }

    VBuffer<Float> stdErrors = new VBuffer<Float>(CurrentWeights.Length, numParams, stdErrorValues, weightIndices);
    _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, ref stdErrors);
}
/// <summary>
/// Prints a boolean literal as <c>true</c> or <c>false</c> to <c>_wrt</c>, then its type via <c>ShowType</c>.
/// </summary>
/// <param name="node">The boolean literal node.</param>
public override void Visit(BoolLitNode node)
{
    Contracts.AssertValue(node);

    string keyword;
    if (node.Value)
    {
        keyword = "true";
    }
    else
    {
        keyword = "false";
    }

    _wrt.Write(keyword);
    ShowType(node);
}
/// <summary>
/// Wraps <paramref name="input"/>, which is exposed through <c>Input</c>.
/// </summary>
/// <param name="input">The row being wrapped; must be non-null.</param>
private protected WrappingRow(Row input)
{
    Contracts.AssertValue(input);

    Input = input;
}
/// <summary>
/// Wraps <paramref name="pred"/>, storing it in <c>_pred</c>.
/// </summary>
/// <param name="pred">The model parameters whose weights are exposed; must be non-null.</param>
public WeightsCollection(LinearModelParameters pred)
{
    Contracts.AssertValue(pred);

    _pred = pred;
}