Esempio n. 1
0
            /// <summary>
            /// Loads the column settings from <paramref name="ctx"/>'s reader, in the order given by the
            /// binary format below. Each integer setting is validated with a decode check immediately
            /// after it is read, so a failed check stops any further reading.
            /// </summary>
            /// <param name="ectx">Exception context used for the assertions and decode checks.</param>
            /// <param name="ctx">Model load context that supplies the reader.</param>
            public ColInfoEx(IExceptionContext ectx, ModelLoadContext ctx)
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(ctx);

                var reader = ctx.Reader;

                // *** Binary format ***
                // int NumTopic;
                // Single AlphaSum;
                // Single Beta;
                // int MHStep;
                // int NumIter;
                // int LikelihoodInterval;
                // int NumThread;
                // int NumMaxDocToken;
                // int NumSummaryTermPerTopic;
                // int NumBurninIter;
                // byte ResetRandomGenerator;

                // Topic count must be strictly positive.
                NumTopic = reader.ReadInt32();
                ectx.CheckDecode(NumTopic > 0);

                // The two floating-point settings are read without a range check.
                AlphaSum = reader.ReadSingle();
                Beta     = reader.ReadSingle();

                // Strictly positive counters.
                MHStep = reader.ReadInt32();
                ectx.CheckDecode(MHStep > 0);
                NumIter = reader.ReadInt32();
                ectx.CheckDecode(NumIter > 0);
                LikelihoodInterval = reader.ReadInt32();
                ectx.CheckDecode(LikelihoodInterval > 0);

                // Zero is an accepted thread count.
                NumThread = reader.ReadInt32();
                ectx.CheckDecode(NumThread >= 0);

                NumMaxDocToken = reader.ReadInt32();
                ectx.CheckDecode(NumMaxDocToken > 0);
                NumSummaryTermPerTopic = reader.ReadInt32();
                ectx.CheckDecode(NumSummaryTermPerTopic > 0);

                // Zero is an accepted burn-in iteration count.
                NumBurninIter = reader.ReadInt32();
                ectx.CheckDecode(NumBurninIter >= 0);

                ResetRandomGenerator = reader.ReadBoolByte();
            }
            /// <summary>
            /// Writes the ungroup mode, the number of pivot columns, and then each pivot column name
            /// to the given save context.
            /// </summary>
            /// <param name="ctx">Model save context to write to.</param>
            public void Save(ModelSaveContext ctx)
            {
                _ectx.AssertValue(ctx);

                // *** Binary format ***
                // int: ungroup mode
                // int: K - number of pivot columns
                // int[K]: ids of pivot column names

                var writer     = ctx.Writer;
                int pivotCount = _infos.Length;

                writer.Write((int)Mode);
                writer.Write(pivotCount);

                for (int i = 0; i < pivotCount; i++)
                {
                    ctx.SaveNonEmptyString(_infos[i].Name);
                }
            }
        /// <summary>
        /// Reads from <paramref name="rdr"/> into a newly allocated buffer whose size equals the
        /// length of the reader's underlying stream.
        /// </summary>
        /// <param name="ectx">Exception context used for the assertions and decode checks.</param>
        /// <param name="rdr">Reader over a seekable stream.</param>
        /// <returns>A buffer of exactly <c>rdr.BaseStream.Length</c> bytes.</returns>
        /// <remarks>
        /// A decode check fails when the stream is longer than <see cref="int.MaxValue"/> bytes, or when
        /// fewer than <c>Length</c> bytes can be read before the stream ends (for example, when the
        /// reader is not positioned at the start of the stream).
        /// </remarks>
        private static byte[] ReadAllBytes(IExceptionContext ectx, BinaryReader rdr)
        {
            Contracts.AssertValue(ectx);
            ectx.AssertValue(rdr);
            ectx.Assert(rdr.BaseStream.CanSeek);

            long size = rdr.BaseStream.Length;

            ectx.CheckDecode(size <= int.MaxValue);

            var rgb = new byte[(int)size];

            // A single Read call may legitimately return fewer bytes than requested even though more
            // data is available, so keep reading until the buffer is full or the stream is exhausted.
            int cb = 0;
            while (cb < rgb.Length)
            {
                int cbRead = rdr.Read(rgb, cb, rgb.Length - cb);
                if (cbRead <= 0)
                {
                    break;
                }
                cb += cbRead;
            }

            // Anything short of a full buffer means the stream ended early.
            ectx.CheckDecode(cb == rgb.Length);

            return(rgb);
        }
        /// <summary>
        /// Creates a value map typed after column 1 of the loader and trains it from a cursor over
        /// all rows, using column 0 as the term column and column 1 as the value column.
        /// </summary>
        /// <param name="ectx">Exception context used for assertions and passed on to training.</param>
        /// <param name="ldr">Loader expected to expose exactly two columns, the first being text.</param>
        /// <returns>The trained value map.</returns>
        private static ValueMap Train(IExceptionContext ectx, BinaryLoader ldr)
        {
            Contracts.AssertValue(ectx);
            ectx.AssertValue(ldr);

            var loaderSchema = ldr.Schema;
            ectx.Assert(loaderSchema.ColumnCount == 2);

            // REVIEW: Should we allow term to be a vector of text (each term in the vector
            // would map to the same value)?
            ectx.Assert(loaderSchema.GetColumnType(0).IsText);

            // REVIEW: We should know the number of rows - use that info to set initial capacity.
            var map = ValueMap.Create(loaderSchema.GetColumnType(1));

            // The predicate activates every column.
            using (var rows = ldr.GetRowCursor(c => true))
            {
                map.Train(ectx, rows, 0, 1);
            }

            return map;
        }
            /// <summary>
            /// Binds the pivot columns against <paramref name="inputSchema"/> and builds two lookups:
            /// pivot column name to input column index, and input column index to position in the
            /// bound pivot list (-1 for columns that are not pivots).
            /// </summary>
            /// <param name="ectx">Exception context; may be null.</param>
            /// <param name="inputSchema">Schema of the input data.</param>
            /// <param name="mode">Ungroup mode to record.</param>
            /// <param name="pivotColumns">Non-empty list of pivot column names.</param>
            public SchemaImpl(IExceptionContext ectx, ISchema inputSchema, UngroupMode mode, string[] pivotColumns)
            {
                Contracts.AssertValueOrNull(ectx);
                _ectx = ectx;
                _ectx.AssertValue(inputSchema);
                _ectx.AssertNonEmpty(pivotColumns);

                _inputSchema = inputSchema;
                Mode = mode;

                CheckAndBind(_ectx, inputSchema, pivotColumns, out _infos);

                _pivotColMap = new Dictionary<string, int>();
                _pivotIndex = Utils.CreateArray(_inputSchema.ColumnCount, -1);

                int position = 0;
                foreach (var pivot in _infos)
                {
                    _pivotColMap[pivot.Name] = pivot.Index;

                    // Each input column may be claimed by at most one pivot entry.
                    _ectx.Assert(_pivotIndex[pivot.Index] == -1);
                    _pivotIndex[pivot.Index] = position;

                    position++;
                }
            }
Esempio n. 6
0
            /// <summary>
            /// Stores the column settings and vocabulary size, then creates the underlying
            /// <c>LdaSingleBox</c> trainer from those settings.
            /// </summary>
            /// <param name="ectx">Exception context used for assertions.</param>
            /// <param name="ex">Column settings that configure the trainer.</param>
            /// <param name="numVocab">Vocabulary size; must be non-negative.</param>
            public LdaState(IExceptionContext ectx, ColInfoEx ex, int numVocab)
                : this()
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(ex, "ex");
                ectx.Assert(numVocab >= 0);

                _numVocab = numVocab;
                InfoEx = ex;

                // Arguments are positional: note that MHStep comes after NumThread, and that the
                // max-doc-token limit is the last argument.
                _ldaTrainer = new LdaSingleBox(
                    ex.NumTopic,
                    numVocab, // Need to set number of vocabulary here
                    ex.AlphaSum,
                    ex.Beta,
                    ex.NumIter,
                    ex.LikelihoodInterval,
                    ex.NumThread,
                    ex.MHStep,
                    ex.NumSummaryTermPerTopic,
                    false, // NOTE(review): meaning of this flag is not visible here - confirm against LdaSingleBox
                    ex.NumMaxDocToken);
            }
            /// <summary>
            /// Resolves every pivot column name against <paramref name="inputSchema"/> and records its
            /// name, column index, vector size and item type.
            /// </summary>
            /// <param name="ectx">Exception context; may be null.</param>
            /// <param name="inputSchema">Schema in which the pivot columns are looked up.</param>
            /// <param name="pivotColumns">Non-empty list of pivot column names.</param>
            /// <param name="infos">Receives one entry per pivot column, in the same order.</param>
            /// <remarks>
            /// Raises a user-argument error when a name is null or empty, when a column is not found,
            /// or when a column is not a vector of a primitive item type.
            /// </remarks>
            private static void CheckAndBind(IExceptionContext ectx, Schema inputSchema,
                string[] pivotColumns, out PivotColumnInfo[] infos)
            {
                Contracts.AssertValueOrNull(ectx);
                ectx.AssertValue(inputSchema);
                ectx.AssertNonEmpty(pivotColumns);

                infos = new PivotColumnInfo[pivotColumns.Length];
                for (int i = 0; i < pivotColumns.Length; i++)
                {
                    string columnName = pivotColumns[i];
                    // REVIEW: replace Check with CheckUser, once existing CheckUser is renamed to CheckUserArg or something.
                    ectx.CheckUserArg(!string.IsNullOrEmpty(columnName), nameof(Arguments.Column), "Column name cannot be empty");

                    int columnIndex;
                    bool found = inputSchema.TryGetColumnIndex(columnName, out columnIndex);
                    if (!found)
                    {
                        throw ectx.ExceptUserArg(nameof(Arguments.Column), "Pivot column '{0}' is not found", columnName);
                    }

                    var columnType = inputSchema.GetColumnType(columnIndex);
                    bool isPrimitiveVector = columnType.IsVector && columnType.ItemType.IsPrimitive;
                    if (!isPrimitiveVector)
                    {
                        throw ectx.ExceptUserArg(nameof(Arguments.Column),
                            "Pivot column '{0}' has type '{1}', but must be a vector of primitive types", columnName, columnType);
                    }

                    infos[i] = new PivotColumnInfo(columnName, columnIndex, columnType.VectorSize, (PrimitiveType)columnType.ItemType);
                }
            }
Esempio n. 8
0
        /// <summary>
        /// Checks whether this object is consistent with an actual schema shape from a dynamic object,
        /// throwing exceptions if not.
        /// </summary>
        /// <remarks>
        /// Every (name, declared type) entry in <c>Pairs</c> is looked up in <paramref name="shape"/>.
        /// A missing column, or a column whose resolved type cannot be assigned to the declared type,
        /// results in an exception created through <c>ectx.ExceptParam</c>.
        /// </remarks>
        /// <param name="ectx">The context on which to throw exceptions</param>
        /// <param name="shape">The schema shape to check</param>
        public void Check(IExceptionContext ectx, SchemaShape shape)
        {
            Contracts.AssertValue(ectx);
            ectx.AssertValue(shape);

            foreach (var pair in Pairs)
            {
                // pair.Key is the column name, pair.Value the statically declared type.
                var col = shape.FindColumn(pair.Key);
                if (col == null)
                {
                    throw ectx.ExceptParam(nameof(shape), $"Column named '{pair.Key}' was not found");
                }
                // Null when the column's type could not be mapped to a recognized static type.
                var type = GetTypeOrNull(col);
                if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value)))
                {
                    // When not null, we can use IsAssignableFrom to indicate we could assign to this, so as to allow
                    // for example Key<uint, string> to be considered to be compatible with Key<uint>.

                    // In the null case, while we cannot directly verify an unrecognized type, we can at least verify
                    // that the statically declared type should not have corresponded to a recognized type.

                    // NOTE(review): on the null path the calls below receive a null 'type' (as the argument here,
                    // and as the receiver in the Key<,> test further down). The outcome then depends on how
                    // IsAssignableFromStaticPipeline treats null, which is not visible here - confirm.
                    if (!pair.Value.IsAssignableFromStaticPipeline(type))
                    {
                        // This is generally an error, unless it's the situation where the asserted type is Key<,> but we could
                        // only resolve it so far as Key<>, since for the moment the SchemaShape cannot determine the type of key
                        // value metadata. In which case, we can check if the declared type is a subtype of the key that was determined
                        // from the analysis.
                        if (pair.Value.IsGenericType && pair.Value.GetGenericTypeDefinition() == typeof(Key <,>) &&
                            type.IsAssignableFromStaticPipeline(pair.Value))
                        {
                            continue;
                        }
                        throw ectx.ExceptParam(nameof(shape),
                                               $"Column '{pair.Key}' of type '{col.GetTypeString()}' cannot be expressed statically as type '{pair.Value}'.");
                    }
                }
            }
        }
            /// <summary>
            /// Binds the owning mapper to <paramref name="schema"/> and builds the output schema with three
            /// vector columns, added in this order: per-tree output values, leaf indicators, and path
            /// (internal node) indicators. Each column carries slot-name metadata; the leaf and path
            /// columns are additionally marked as normalized.
            /// </summary>
            /// <param name="ectx">Exception context used for assertions.</param>
            /// <param name="owner">Mapper that supplies the ensemble, the leaf count and the slot-name getters.</param>
            /// <param name="schema">Role-mapped input schema; it must have a feature column.</param>
            public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner,
                               RoleMappedSchema schema)
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(owner);
                ectx.AssertValue(schema);
                ectx.Assert(schema.Feature.HasValue);

                _ectx = ectx;
                _owner = owner;
                InputRoleMappedSchema = schema;

                var treeCount = owner._ensemble.TrainedEnsemble.NumTrees;
                var leafCount = owner._totalLeafCount;

                // One slot per tree: the output of each tree on a given example.
                var treeValueType = new VectorType(NumberType.Float, treeCount);

                // One slot per leaf in the whole ensemble: indicates which leaf the example ends up in
                // for each tree.
                var leafIdType = new VectorType(NumberType.Float, leafCount);

                // One slot per internal node in the whole ensemble: indicates the nodes on the paths the
                // example takes through the trees.
                // In a binary tree, #internal + #leaf = 2 * #internal + 1 (every node except the root is a
                // child of some internal node), hence #internal = #leaf - 1 per tree. Summed over the
                // ensemble this gives #leaf - #trees internal nodes.
                var pathIdType = new VectorType(NumberType.Float, leafCount - treeCount);

                // Assemble the output schema. Columns must be added in the order Trees, Leaves, Paths.
                var outputBuilder = new SchemaBuilder();

                // Column 1: tree output values, with slot names.
                var treeMetadata = new MetadataBuilder();
                treeMetadata.Add(
                    MetadataUtils.Kinds.SlotNames,
                    MetadataUtils.GetNamesType(treeValueType.Size),
                    (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetTreeSlotNames);
                outputBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeMetadata.GetMetadata());

                // Column 2: leaf IDs the example reaches, with slot names and the normalized flag.
                var leafMetadata = new MetadataBuilder();
                leafMetadata.Add(
                    MetadataUtils.Kinds.SlotNames,
                    MetadataUtils.GetNamesType(leafIdType.Size),
                    (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetLeafSlotNames);
                leafMetadata.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);
                outputBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafMetadata.GetMetadata());

                // Column 3: encoded paths the example passes, with slot names and the normalized flag.
                var pathMetadata = new MetadataBuilder();
                pathMetadata.Add(
                    MetadataUtils.Kinds.SlotNames,
                    MetadataUtils.GetNamesType(pathIdType.Size),
                    (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetPathSlotNames);
                pathMetadata.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);
                outputBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathMetadata.GetMetadata());

                OutputSchema = outputBuilder.GetSchema();

                // The column positions are relied upon elsewhere through these fixed IDs:
                // tree values first, leaf IDs second, path IDs third.
                Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
                Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
                Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);
            }
Esempio n. 10
0
        /// <summary>
        /// Determines the destination item type for converting the source column of <paramref name="info"/>
        /// to <paramref name="kind"/> (as a key type when <paramref name="range"/> is given), and verifies
        /// that a conversion exists: either a standard conversion or one of the additional numeric / text
        /// mappings handled below.
        /// </summary>
        /// <param name="ectx">Exception context used for assertions; also captured by the range-checking mappers.</param>
        /// <param name="info">Column whose source type (and optional slot type) is converted.</param>
        /// <param name="kind">Requested destination data kind.</param>
        /// <param name="range">Optional key count; when non-null the destination item type is a key built from it.</param>
        /// <param name="itemType">Receives the destination item type. It is assigned even when the method returns false.</param>
        /// <param name="ex">Receives the extended column info on success; null otherwise.</param>
        /// <returns>True when a conversion is available and <paramref name="ex"/> has been created.</returns>
        private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyCount range,
                                        out PrimitiveDataViewType itemType, out ColInfoEx ex)
        {
            ectx.AssertValue(info);
            ectx.Assert(Enum.IsDefined(typeof(DataKind), kind));

            ex = null;

            var typeSrc = info.TypeSrc;

            if (range != null)
            {
                itemType = TypeParsingUtils.ConstructKeyType(SchemaHelper.DataKind2InternalDataKind(kind), range);
            }
            else if (!typeSrc.ItemType().IsKey())
            {
                itemType = ColumnTypeHelper.PrimitiveFromKind(kind);
            }
            else if (!ColumnTypeHelper.IsValidDataKind(kind))
            {
                itemType = ColumnTypeHelper.PrimitiveFromKind(kind);
                return(false);
            }
            else
            {
                var key = typeSrc.ItemType().AsKey();
                ectx.Assert(ColumnTypeHelper.IsValidDataKind(key.RawKind()));
                ulong count = key.Count;
                // Technically, it's an error for the counts not to match, but we'll let the Conversions
                // code return false below. There's a possibility we'll change the standard conversions to
                // map out of bounds values to zero, in which case, this is the right thing to do.

                // Largest key count the destination kind can represent. Casting the enum value itself to
                // ulong would yield the enum's ordinal rather than a maximum, and would clamp the count
                // to a tiny number.
                ulong max;
                switch (kind)
                {
                case DataKind.Byte:
                    max = byte.MaxValue;
                    break;

                case DataKind.UInt16:
                    max = ushort.MaxValue;
                    break;

                case DataKind.UInt32:
                    max = uint.MaxValue;
                    break;

                default:
                    // UInt64 (and anything else that passed the validity check above): no clamping needed.
                    max = ulong.MaxValue;
                    break;
                }
                if (count > max)
                {
                    count = max;
                }
                itemType = new KeyDataViewType(SchemaHelper.DataKind2ColumnType(kind).RawType, count);
            }

            // Ensure that the conversion is legal. We don't actually cache the delegate here. It will get
            // re-fetched by the utils code when needed.
            bool     identity;
            Delegate del;

            if (!Conversions.DefaultInstance.TryGetStandardConversion(typeSrc.ItemType(), itemType, out del, out identity))
            {
                // No standard conversion: fall back to the extra mappings supported here.
                // In the 'plus' computations below the subtraction is unsigned, so a key source with a
                // non-key destination wraps around and 'src + plus' effectively subtracts one
                // (assuming the default unchecked arithmetic context).
                if (typeSrc.ItemType().RawKind() == itemType.RawKind())
                {
                    switch (typeSrc.ItemType().RawKind())
                    {
                    case DataKind.UInt32:
                        // Key starts at 1.
                        // Multiclass future issue
                        uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                        identity = false;
                        ValueMapper <uint, uint> map_ = (in uint src, ref uint dst) => { dst = src + plus; };
                        del = map_;
                        break;

                    default:
                        throw Contracts.Except("Not suppoted type {0}", typeSrc.ItemType().RawKind());
                    }
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Int64 && kind == DataKind.UInt64)
                {
                    ulong plus = (itemType.IsKey() ? (ulong)1 : (ulong)0) - (typeSrc.IsKey() ? (ulong)1 : (ulong)0);
                    identity = false;
                    ValueMapper <long, ulong> map_ = (in long src, ref ulong dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (ulong)src + plus;
                    };
                    del = map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.UInt64)
                {
                    ulong plus = (itemType.IsKey() ? (ulong)1 : (ulong)0) - (typeSrc.IsKey() ? (ulong)1 : (ulong)0);
                    identity = false;
                    ValueMapper <float, ulong> map_ = (in float src, ref ulong dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (ulong)src + plus;
                    };
                    del = map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Int64 && kind == DataKind.UInt32)
                {
                    // Multiclass future issue
                    uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                    identity = false;
                    ValueMapper <long, uint> map_ = (in long src, ref uint dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (uint)src + plus;
                    };
                    del = map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.UInt32)
                {
                    // Multiclass future issue
                    uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                    identity = false;
                    ValueMapper <float, uint> map_ = (in float src, ref uint dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (uint)src + plus;
                    };
                    del = map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.String)
                {
                    // Multiclass future issue
                    identity = false;
                    ValueMapper <float, DvText> map_ = (in float src, ref DvText dst) =>
                    {
                        dst = new DvText(string.Format("{0}", (int)src));
                    };
                    del = map_;
                }
                else
                {
                    return(false);
                }
            }

            DataViewType typeDst = itemType;

            if (typeSrc.IsVector())
            {
                // Vector sources keep their dimensions; only the item type changes.
                typeDst = new VectorDataViewType(itemType, typeSrc.AsVector().Dimensions.ToArray());
            }

            // An output column is transposable iff the input column was transposable.
            VectorDataViewType slotType = null;

            if (info.SlotTypeSrc != null)
            {
                slotType = new VectorDataViewType(itemType, info.SlotTypeSrc.Dimensions.ToArray());
            }

            ex = new ColInfoEx(kind, range != null, typeDst, slotType);
            return(true);
        }
        /// <summary>
        /// Determines the destination item type for converting the source column of <paramref name="info"/>
        /// to <paramref name="kind"/> (as a key type when <paramref name="range"/> is given), and checks
        /// that a standard conversion from the source item type exists.
        /// </summary>
        /// <param name="ectx">Exception context used for assertions.</param>
        /// <param name="info">Column whose source type (and optional slot type) is converted.</param>
        /// <param name="kind">Requested destination data kind.</param>
        /// <param name="range">Optional key range; when non-null the source item type must be a key or text.</param>
        /// <param name="itemType">Receives the destination item type. It is assigned even when the method returns false.</param>
        /// <param name="ex">Receives the extended column info on success; null otherwise.</param>
        /// <returns>True when the conversion is legal and <paramref name="ex"/> has been created.</returns>
        private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyRange range, out PrimitiveType itemType, out ColInfoEx ex)
        {
            ectx.AssertValue(info);
            ectx.Assert(Enum.IsDefined(typeof(DataKind), kind));

            ex = null;

            var srcType = info.TypeSrc;
            bool hasRange = range != null;

            if (hasRange)
            {
                // An explicit range always yields a key type, but only key or text sources qualify.
                itemType = TypeParsingUtils.ConstructKeyType(kind, range);
                if (!(srcType.ItemType.IsKey || srcType.ItemType.IsText))
                {
                    return false;
                }
            }
            else if (srcType.ItemType.IsKey && KeyType.IsValidDataKind(kind))
            {
                // Key source and a key-capable destination kind: keep the key shape, clamping the count.
                var srcKey = srcType.ItemType.AsKey;
                ectx.Assert(KeyType.IsValidDataKind(srcKey.RawKind));

                // Technically, it's an error for the counts not to match, but we'll let the Conversions
                // code return false below. There's a possibility we'll change the standard conversions to
                // map out of bounds values to zero, in which case, this is the right thing to do.
                ulong limit = kind.ToMaxInt();
                int keyCount = srcKey.Count;
                keyCount = (ulong)keyCount > limit ? (int)limit : keyCount;

                itemType = new KeyType(kind, srcKey.Min, keyCount, srcKey.Contiguous);
            }
            else
            {
                itemType = PrimitiveType.FromKind(kind);

                // A key source cannot be converted to a kind that is not valid for keys.
                if (srcType.ItemType.IsKey)
                {
                    return false;
                }
            }

            // Ensure that the conversion is legal. We don't actually cache the delegate here. It will get
            // re-fetched by the utils code when needed.
            Delegate conversion;
            bool isIdentity;
            bool convertible = Conversions.Instance.TryGetStandardConversion(srcType.ItemType, itemType, out conversion, out isIdentity);
            if (!convertible)
            {
                return false;
            }

            ColumnType dstType = srcType.IsVector
                ? new VectorType(itemType, srcType.AsVector)
                : (ColumnType)itemType;

            // An output column is transposable iff the input column was transposable.
            VectorType slotType = info.SlotTypeSrc != null
                ? new VectorType(itemType, info.SlotTypeSrc)
                : null;

            ex = new ColInfoEx(kind, hasRange, dstType, slotType);
            return true;
        }
Esempio n. 12
0
        /// <summary>
        /// Converts a JSON token into an instance of <paramref name="type"/>.
        /// </summary>
        /// <remarks>
        /// A missing token or a JSON null yields null. <c>Optional&lt;&gt;</c> and <c>Nullable&lt;&gt;</c> are
        /// unwrapped to their argument type. <c>Var&lt;&gt;</c> types are bound to the variable named by the
        /// token. Otherwise the data kind reported by <c>TlcModule.GetDataType</c> selects the conversion.
        /// A <see cref="FormatException"/> raised while converting is rethrown unchanged when already
        /// marked, and wrapped through <c>ectx.Except</c> otherwise.
        /// </remarks>
        /// <param name="ectx">Exception context used for assertions, checks and raised exceptions.</param>
        /// <param name="type">Target type of the value.</param>
        /// <param name="attributes">Attributes forwarded to the array and dictionary builders.</param>
        /// <param name="value">JSON token to convert; may be null.</param>
        /// <param name="catalog">Component catalog used for component and nested input resolution.</param>
        /// <returns>The parsed value, or null when the token is missing or JSON null.</returns>
        private static object ParseJsonValue(IExceptionContext ectx, Type type, Attributes attributes, JToken value, ComponentCatalog catalog)
        {
            Contracts.AssertValue(ectx);
            ectx.AssertValue(type);
            ectx.AssertValueOrNull(value);
            ectx.AssertValue(catalog);

            // A missing token and an explicit JSON null both produce a null result.
            if (value == null || (value is JValue scalar && scalar.Value == null))
            {
                return null;
            }

            // Unwrap Optional<T> / Nullable<T> down to T. For Optional<T>, a token with children is
            // replaced by its first child value.
            if (type.IsGenericType)
            {
                var wrapperDefinition = type.GetGenericTypeDefinition();
                bool isOptional = wrapperDefinition == typeof(Optional<>);
                if (isOptional || wrapperDefinition == typeof(Nullable<>))
                {
                    if (isOptional && value.HasValues)
                    {
                        value = value.Values().FirstOrDefault();
                    }
                    type = type.GetGenericArguments()[0];
                }
            }

            // Variable references: the token must be a binding token naming the variable.
            if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>))
            {
                string bindingText = value.Value<string>();
                ectx.Check(VariableBinding.IsBindingToken(value), "Variable name expected.");
                var holder  = Activator.CreateInstance(type) as IVarSerializationHelper;
                var binding = VariableBinding.Create(ectx, bindingText);
                holder.VarName = binding.VariableName;
                return holder;
            }

            // Raw JSON arrays are passed through untouched.
            if (value is JArray && type == typeof(JArray))
            {
                return value;
            }

            TlcModule.DataKind dataKind = TlcModule.GetDataType(type);

            try
            {
                switch (dataKind)
                {
                case TlcModule.DataKind.Bool:
                    return value.Value<bool>();

                case TlcModule.DataKind.String:
                    return value.Value<string>();

                case TlcModule.DataKind.Char:
                    return value.Value<char>();

                case TlcModule.DataKind.Enum:
                    string enumText = value.Value<string>();
                    if (!Enum.IsDefined(type, enumText))
                    {
                        throw ectx.Except($"Requested value '{enumText}' is not a member of the Enum type '{type.Name}'");
                    }
                    return Enum.Parse(type, enumText);

                case TlcModule.DataKind.Float:
                    if (type == typeof(double))
                    {
                        return value.Value<double>();
                    }
                    if (type == typeof(float))
                    {
                        return value.Value<float>();
                    }
                    ectx.Assert(false);
                    throw ectx.ExceptNotSupp();

                case TlcModule.DataKind.Array:
                    var arrayToken = value as JArray;
                    ectx.Check(arrayToken != null, "Expected array value");
                    // MakeArray<int> is only a template; MarshalInvoke re-targets it to the element type.
                    Func<IExceptionContext, JArray, Attributes, ComponentCatalog, object> arrayFactory = MakeArray<int>;
                    return Utils.MarshalInvoke(arrayFactory, type.GetElementType(), ectx, arrayToken, attributes, catalog);

                case TlcModule.DataKind.Int:
                    if (type == typeof(long))
                    {
                        return value.Value<long>();
                    }
                    else if (type == typeof(int))
                    {
                        return value.Value<int>();
                    }
                    ectx.Assert(false);
                    throw ectx.ExceptNotSupp();

                case TlcModule.DataKind.UInt:
                    if (type == typeof(ulong))
                    {
                        return value.Value<ulong>();
                    }
                    else if (type == typeof(uint))
                    {
                        return value.Value<uint>();
                    }
                    ectx.Assert(false);
                    throw ectx.ExceptNotSupp();

                case TlcModule.DataKind.Dictionary:
                    var dictionaryToken = value as JObject;
                    ectx.Check(dictionaryToken != null, "Expected object value");
                    // MakeDictionary<int> is only a template; MarshalInvoke re-targets it to the second
                    // generic argument of the dictionary type.
                    Func<IExceptionContext, JObject, Attributes, ComponentCatalog, object> dictionaryFactory = MakeDictionary<int>;
                    return Utils.MarshalInvoke(dictionaryFactory, type.GetGenericArguments()[1], ectx, dictionaryToken, attributes, catalog);

                case TlcModule.DataKind.Component:
                    var componentToken = value as JObject;
                    ectx.Check(componentToken != null, "Expected object value");
                    // REVIEW: consider accepting strings alone.
                    var nameToken = componentToken[FieldNames.Name];
                    ectx.Check(nameToken != null, $"Field '{FieldNames.Name}' is required for component.");
                    ectx.Check(nameToken is JValue, $"Expected '{FieldNames.Name}' field to be a string.");
                    var componentName = nameToken.Value<string>();
                    var settingsToken = componentToken[FieldNames.Settings];
                    ectx.Check(settingsToken == null || settingsToken is JObject,
                               $"Expected '{FieldNames.Settings}' field to be an object");
                    return GetComponentJson(ectx, type, componentName, settingsToken as JObject, catalog);

                default:
                    // Anything else is treated as a nested input object filled in field by field.
                    var fieldsToken = value as JObject;
                    ectx.Check(fieldsToken != null, "Expected object value");
                    var builder = new InputBuilder(ectx, type, catalog);

                    if (builder._fields.Length == 0)
                    {
                        throw ectx.Except($"Unsupported input type: {dataKind}");
                    }

                    if (fieldsToken != null)
                    {
                        foreach (var field in fieldsToken)
                        {
                            bool accepted = builder.TrySetValueJson(field.Key, field.Value);
                            if (!accepted)
                            {
                                throw ectx.Except($"Unexpected value for component '{type}', field '{field.Key}': '{field.Value}'");
                            }
                        }
                    }

                    var missingInputs = builder.GetMissingValues().ToArray();
                    if (missingInputs.Length > 0)
                    {
                        throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missingInputs)}");
                    }
                    return builder.GetInstance();
                }
            }
            catch (FormatException formatError)
            {
                // Already-marked exceptions originate from this framework; let them propagate as-is.
                if (formatError.IsMarked())
                {
                    throw;
                }
                throw ectx.Except(formatError, $"Failed to parse JSON value '{value}' as {type}");
            }
        }
Esempio n. 13
0
        /// <summary>
        /// Produce the JSON token for a component factory that is used as a default value. The factory's runtime type is
        /// looked up in the catalog. When an entry is found, the result is an object holding the component name and, if at
        /// least one was emitted, a settings object with every argument field that is either required or whose value differs
        /// from that of a freshly constructed factory.
        ///
        /// This setup is inherently fragile when the factory is not trivial, but it works well for the 'property bag'
        /// factories that are currently in use.
        /// </summary>
        private static JToken BuildComponentToken(IExceptionContext ectx, IComponentFactory value, ModuleCatalog catalog)
        {
            Contracts.AssertValueOrNull(ectx);
            ectx.AssertValue(value);
            ectx.AssertValue(catalog);

            var factoryType = value.GetType();

            if (!catalog.TryFindComponent(factoryType, out ModuleCatalog.ComponentInfo info))
            {
                // The default component is not in the catalog. This is, technically, allowed, but it means there is no
                // JSON representation for the default value, so emit a placeholder that will not parse back.
                return new JValue("(custom component)");
            }

            ectx.Assert(info.ArgumentType == factoryType);

            // A default-constructed factory supplies the baseline values to compare against.
            object pristine;
            try
            {
                pristine = Activator.CreateInstance(factoryType);
            }
            catch (MissingMemberException ex)
            {
                // No default constructor. This should never happen, since ModuleCatalog would error out
                // if there is no default ctor.
                ectx.Assert(false);
                throw ectx.Except(ex, "Couldn't find default constructor");
            }

            // Collect every visible argument field that is required or deviates from the baseline.
            var settings = new JObject();
            foreach (var field in factoryType.GetFields())
            {
                if (!(field.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() is ArgumentAttribute argAttr) ||
                    argAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
                {
                    continue;
                }
                ectx.Assert(!field.IsStatic && !field.IsInitOnly && !field.IsLiteral);

                object current = field.GetValue(value);

                // Required arguments are always written; the baseline is only consulted for optional ones.
                bool mustEmit = argAttr.IsRequired || !Equals(current, field.GetValue(pristine));
                if (mustEmit)
                {
                    settings[argAttr.Name ?? field.Name] = BuildValueToken(ectx, current, field.FieldType, catalog);
                }
            }

            var result = new JObject();
            result[FieldNames.Name] = info.Name;

            // The settings object is attached only when at least one field was written into it.
            if (settings.Count > 0)
            {
                result[FieldNames.Settings] = settings;
            }
            return result;
        }
Esempio n. 14
0
        /// <summary>
        /// Build the JSON token that represents <paramref name="value"/>, interpreted as a value of <paramref name="valueType"/>.
        /// Nullable and Optional wrappers are looked through. Primitive kinds become plain JSON values, arrays become JSON
        /// arrays (recursively), enums are written by name and components are delegated to <c>BuildComponentToken</c>.
        /// </summary>
        /// <param name="ectx">Exception context used for assertions and for creating exceptions.</param>
        /// <param name="value">The value to convert. May be null.</param>
        /// <param name="valueType">The declared type of the value. Must not be null.</param>
        /// <param name="catalog">The module catalog, used to resolve component factories.</param>
        /// <returns>The JSON token for the value, or null when <paramref name="value"/> is null.</returns>
        private static JToken BuildValueToken(IExceptionContext ectx, object value, Type valueType, ModuleCatalog catalog)
        {
            Contracts.AssertValueOrNull(ectx);
            ectx.AssertValueOrNull(value);
            ectx.AssertValue(valueType);
            ectx.AssertValue(catalog);

            if (value == null)
            {
                return(null);
            }

            // Dive inside Nullable.
            if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(Nullable <>))
            {
                valueType = valueType.GetGenericArguments()[0];
            }

            // Dive inside Optional.
            if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(Optional <>))
            {
                valueType = valueType.GetGenericArguments()[0];
                value     = ((Optional)value).GetValue();
            }

            var dataType = TlcModule.GetDataType(valueType);

            switch (dataType)
            {
            case TlcModule.DataKind.Bool:
            case TlcModule.DataKind.Int:
            case TlcModule.DataKind.UInt:
            case TlcModule.DataKind.Float:
            case TlcModule.DataKind.String:
                return(new JValue(value));

            case TlcModule.DataKind.Char:
                return(new JValue(value.ToString()));

            case TlcModule.DataKind.Array:
                var valArray = value as Array;
                ectx.AssertValue(valArray);

                // Non-null items are described by their runtime type, as before. A null item has no runtime
                // type, so fall back to the declared element type; calling GetType() on it would throw.
                var elementType = valueType.GetElementType() ?? typeof(object);
                var ja          = new JArray();
                foreach (var item in valArray)
                {
                    var itemType  = item != null ? item.GetType() : elementType;
                    var itemToken = BuildValueToken(ectx, item, itemType, catalog);
                    // A null item yields a null token; store it as an explicit JSON null.
                    ja.Add(itemToken ?? JValue.CreateNull());
                }
                return(ja);

            case TlcModule.DataKind.Enum:
                return(value.ToString());

            case TlcModule.DataKind.Dictionary:
                // REVIEW: need to figure out how to represent these.
                throw ectx.ExceptNotSupp("Dictionary and component default values are not supported");

            case TlcModule.DataKind.Component:
                var factory = value as IComponentFactory;
                ectx.AssertValue(factory);
                return(BuildComponentToken(ectx, factory, catalog));

            default:
                throw ectx.ExceptNotSupp("Encountered a default value for unsupported type {0}", dataType);
            }
        }
Esempio n. 15
0
        /// <summary>
        /// Describe <paramref name="type"/> as a manifest type token. Simple kinds are returned as the kind name; enums,
        /// arrays, dictionaries, components, state objects and struct-like types are returned as JSON objects with a
        /// kind and the extra information relevant to that kind. <paramref name="fieldInfo"/> is passed through to
        /// recursive calls and is used to name the field when a component kind cannot be resolved.
        /// </summary>
        private static JToken BuildTypeToken(IExceptionContext ectx, FieldInfo fieldInfo, Type type, ModuleCatalog catalog)
        {
            Contracts.AssertValueOrNull(ectx);
            ectx.AssertValue(type);
            ectx.AssertValue(catalog);

            // REVIEW: Allows newly introduced types to not break the manifest building process.
            // Where possible, these types should be replaced by component kinds.
            bool isEvaluatorInterface = type == typeof(CommonInputs.IEvaluatorInput) ||
                                        type == typeof(CommonOutputs.IEvaluatorOutput);
            if (isEvaluatorInterface)
            {
                return new JObject
                {
                    [FieldNames.Kind]     = "EntryPoint",
                    [FieldNames.ItemType] = $"{type}".Replace("Microsoft.ML.Runtime.EntryPoints.", "")
                };
            }

            type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(type);

            // Dive inside Var.
            if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var <>))
            {
                type = type.GetGenericArguments()[0];
            }

            var kindOfType = TlcModule.GetDataType(type);
            switch (kindOfType)
            {
                case TlcModule.DataKind.Unknown:
                {
                    if (type == typeof(JArray))
                    {
                        return new JObject
                        {
                            [FieldNames.Kind]     = TlcModule.DataKind.Array.ToString(),
                            [FieldNames.ItemType] = "Node"
                        };
                    }
                    if (type == typeof(JObject))
                    {
                        return "Bindings";
                    }

                    // Anything else is treated as a struct described by its own input manifest.
                    var structFields = BuildInputManifest(ectx, type, catalog);
                    if (structFields.Count == 0)
                    {
                        throw ectx.Except("Unexpected parameter type: {0}", type);
                    }
                    return new JObject
                    {
                        [FieldNames.Kind]   = "Struct",
                        [FieldNames.Fields] = structFields
                    };
                }

                // Simple kinds are represented by the kind name alone.
                case TlcModule.DataKind.Float:
                case TlcModule.DataKind.Int:
                case TlcModule.DataKind.UInt:
                case TlcModule.DataKind.Char:
                case TlcModule.DataKind.String:
                case TlcModule.DataKind.Bool:
                case TlcModule.DataKind.DataView:
                case TlcModule.DataKind.TransformModel:
                case TlcModule.DataKind.PredictorModel:
                case TlcModule.DataKind.FileHandle:
                    return kindOfType.ToString();

                case TlcModule.DataKind.Enum:
                    return new JObject
                    {
                        [FieldNames.Kind]   = kindOfType.ToString(),
                        [FieldNames.Values] = new JArray(Enum.GetNames(type))
                    };

                case TlcModule.DataKind.Array:
                    return new JObject
                    {
                        [FieldNames.Kind]     = kindOfType.ToString(),
                        [FieldNames.ItemType] = BuildTypeToken(ectx, fieldInfo, type.GetElementType(), catalog)
                    };

                case TlcModule.DataKind.Dictionary:
                    // The item type is taken from the second generic argument of the dictionary type.
                    return new JObject
                    {
                        [FieldNames.Kind]     = kindOfType.ToString(),
                        [FieldNames.ItemType] = BuildTypeToken(ectx, fieldInfo, type.GetGenericArguments()[1], catalog)
                    };

                case TlcModule.DataKind.Component:
                {
                    if (!catalog.TryGetComponentKind(type, out string componentKind))
                    {
                        throw ectx.Except("Field '{0}' is a component of unknown kind", fieldInfo.Name);
                    }
                    return new JObject
                    {
                        [FieldNames.Kind]          = kindOfType.ToString(),
                        [FieldNames.ComponentKind] = componentKind
                    };
                }

                case TlcModule.DataKind.State:
                    return new JObject
                    {
                        [FieldNames.Kind]     = "C# Object",
                        [FieldNames.ItemType] = $"{type}".Replace("Microsoft.ML.Runtime.Interfaces.", "")
                    };

                default:
                    ectx.Assert(false);
                    throw ectx.ExceptNotSupp();
            }
        }
Esempio n. 16
0
        /// <summary>
        /// Build the manifest description of the inputs of <paramref name="inputType"/>. Each field returned by
        /// <c>GetFields()</c> that carries an <see cref="ArgumentAttribute"/> and is not command-line-only contributes one
        /// JSON object with its name, type, description, aliases, required flag, sort order, nullability and, where
        /// applicable, its default value, range, deprecation message and sweep parameters.
        /// </summary>
        /// <param name="ectx">Exception context used for assertions and for creating exceptions.</param>
        /// <param name="inputType">The input type to describe. An instance is created with <c>Activator.CreateInstance</c>
        /// in order to read default values.</param>
        /// <param name="catalog">The module catalog, used when describing field types and default values.</param>
        /// <returns>A JSON array with one entry per described field, ordered by the attribute's sort order.</returns>
        private static JArray BuildInputManifest(IExceptionContext ectx, Type inputType, ModuleCatalog catalog)
        {
            Contracts.AssertValueOrNull(ectx);
            ectx.AssertValue(inputType);
            ectx.AssertValue(catalog);

            // Instantiate a value of the input, to pull defaults out of.
            var defaults = Activator.CreateInstance(inputType);

            var inputs = new List <KeyValuePair <Double, JObject> >();

            foreach (var fieldInfo in inputType.GetFields())
            {
                var inputAttr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() as ArgumentAttribute;
                if (inputAttr == null || inputAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
                {
                    continue;
                }
                var jo = new JObject();
                jo[FieldNames.Name] = inputAttr.Name ?? fieldInfo.Name;
                jo[FieldNames.Type] = BuildTypeToken(ectx, fieldInfo, fieldInfo.FieldType, catalog);
                jo[FieldNames.Desc] = inputAttr.HelpText;
                if (inputAttr.Aliases != null)
                {
                    jo[FieldNames.Aliases] = new JArray(inputAttr.Aliases);
                }

                jo[FieldNames.Required]   = inputAttr.IsRequired;
                jo[FieldNames.SortOrder]  = inputAttr.SortOrder;
                jo[FieldNames.IsNullable] = fieldInfo.FieldType.IsGenericType && (fieldInfo.FieldType.GetGenericTypeDefinition() == typeof(Nullable <>));

                // Optional inputs always get a default entry; required ones only when a non-null default
                // of a known data kind exists.
                var defaultValue = fieldInfo.GetValue(defaults);
                var dataType     = TlcModule.GetDataType(fieldInfo.FieldType);
                if (!inputAttr.IsRequired || (dataType != TlcModule.DataKind.Unknown && defaultValue != null))
                {
                    jo[FieldNames.Default] = BuildValueToken(ectx, defaultValue, fieldInfo.FieldType, catalog);
                }

                // Optional<> defaults must be constructed (unless the input is required) and must be implicit.
                if (fieldInfo.FieldType.IsGenericType &&
                    fieldInfo.FieldType.GetGenericTypeDefinition() == typeof(Optional <>))
                {
                    var val = defaultValue as Optional;
                    if (val == null && !inputAttr.IsRequired)
                    {
                        throw ectx.Except("Field '{0}' is an Optional<> type but is null by default, instead of set to a constructed implicit default.", fieldInfo.Name);
                    }
                    if (val != null && val.IsExplicit)
                    {
                        throw ectx.Except("Field '{0}' is an Optional<> type with a non-implicit default value.", fieldInfo.Name);
                    }
                }

                var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() as TlcModule.RangeAttribute;
                if (rangeAttr != null)
                {
                    if (!TlcModule.IsNumericKind(TlcModule.GetDataType(fieldInfo.FieldType)))
                    {
                        throw ectx.Except("Field '{0}' has a range but is of a non-numeric type.", fieldInfo.Name);
                    }

                    if (!rangeAttr.Type.Equals(fieldInfo.FieldType))
                    {
                        throw ectx.Except("Field '{0}' has a range attribute that uses a type which is not equal to the field's FieldType.", fieldInfo.Name);
                    }

                    // Only the bounds that are actually set on the attribute are written.
                    var jRange = new JObject();
                    if (rangeAttr.Sup != null)
                    {
                        jRange[FieldNames.Range.Sup] = JToken.FromObject(rangeAttr.Sup);
                    }
                    if (rangeAttr.Inf != null)
                    {
                        jRange[FieldNames.Range.Inf] = JToken.FromObject(rangeAttr.Inf);
                    }
                    if (rangeAttr.Max != null)
                    {
                        jRange[FieldNames.Range.Max] = JToken.FromObject(rangeAttr.Max);
                    }
                    if (rangeAttr.Min != null)
                    {
                        jRange[FieldNames.Range.Min] = JToken.FromObject(rangeAttr.Min);
                    }
                    jo[FieldNames.Range.Type] = jRange;
                }

                // Handle deprecated/obsolete attributes, passing along the message to the manifest.
                if (fieldInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() is ObsoleteAttribute obsAttr)
                {
                    var jParam = new JObject
                    {
                        // A bare [Obsolete] has a null Message, and JToken.FromObject rejects null,
                        // so fall back to an empty message instead of failing the whole manifest.
                        [FieldNames.Deprecated.Message] = JToken.FromObject(obsAttr.Message ?? string.Empty),
                    };
                    jo[FieldNames.Deprecated.ToString()] = jParam;
                }

                if (fieldInfo.GetCustomAttributes(typeof(TlcModule.SweepableLongParamAttribute), false).FirstOrDefault() is TlcModule.SweepableLongParamAttribute slpAttr)
                {
                    var jParam = new JObject
                    {
                        [FieldNames.SweepableLongParam.RangeType] = JToken.FromObject("Long"),
                        [FieldNames.SweepableLongParam.Min]       = JToken.FromObject(slpAttr.Min),
                        [FieldNames.SweepableLongParam.Max]       = JToken.FromObject(slpAttr.Max)
                    };
                    if (slpAttr.StepSize != null)
                    {
                        jParam[FieldNames.SweepableLongParam.StepSize] = JToken.FromObject(slpAttr.StepSize);
                    }
                    if (slpAttr.NumSteps != null)
                    {
                        jParam[FieldNames.SweepableLongParam.NumSteps] = JToken.FromObject(slpAttr.NumSteps);
                    }
                    if (slpAttr.IsLogScale)
                    {
                        jParam[FieldNames.SweepableLongParam.IsLogScale] = JToken.FromObject(true);
                    }
                    jo[FieldNames.SweepableLongParam.ToString()] = jParam;
                }

                if (fieldInfo.GetCustomAttributes(typeof(TlcModule.SweepableFloatParamAttribute), false).FirstOrDefault() is TlcModule.SweepableFloatParamAttribute sfpAttr)
                {
                    var jParam = new JObject
                    {
                        [FieldNames.SweepableFloatParam.RangeType] = JToken.FromObject("Float"),
                        [FieldNames.SweepableFloatParam.Min]       = JToken.FromObject(sfpAttr.Min),
                        [FieldNames.SweepableFloatParam.Max]       = JToken.FromObject(sfpAttr.Max)
                    };
                    if (sfpAttr.StepSize != null)
                    {
                        jParam[FieldNames.SweepableFloatParam.StepSize] = JToken.FromObject(sfpAttr.StepSize);
                    }
                    if (sfpAttr.NumSteps != null)
                    {
                        jParam[FieldNames.SweepableFloatParam.NumSteps] = JToken.FromObject(sfpAttr.NumSteps);
                    }
                    if (sfpAttr.IsLogScale)
                    {
                        jParam[FieldNames.SweepableFloatParam.IsLogScale] = JToken.FromObject(true);
                    }
                    jo[FieldNames.SweepableFloatParam.ToString()] = jParam;
                }

                if (fieldInfo.GetCustomAttributes(typeof(TlcModule.SweepableDiscreteParamAttribute), false).FirstOrDefault() is TlcModule.SweepableDiscreteParamAttribute sdpAttr)
                {
                    var jParam = new JObject
                    {
                        [FieldNames.SweepableDiscreteParam.RangeType] = JToken.FromObject("Discrete"),
                        [FieldNames.SweepableDiscreteParam.Options]   = JToken.FromObject(sdpAttr.Options)
                    };
                    jo[FieldNames.SweepableDiscreteParam.ToString()] = jParam;
                }

                inputs.Add(new KeyValuePair <Double, JObject>(inputAttr.SortOrder, jo));
            }

            // Emit the entries ordered by sort order.
            return(new JArray(inputs.OrderBy(x => x.Key).Select(x => x.Value).ToArray()));
        }
Esempio n. 17
0
            /// <summary>
            /// Deserializing constructor. Reads the column settings, the vocabulary size and the two memory block sizes
            /// from <paramref name="ctx"/>, creates the trainer and allocates its model memory, then reads the topic
            /// table of each term and hands it to the trainer. Finally the trainer is prepared for testing unless that
            /// has already been done.
            /// </summary>
            public LdaState(IExceptionContext ectx, ModelLoadContext ctx)
                : this()
            {
                ectx.AssertValue(ctx);

                // *** Binary format ***
                // <ColInfoEx>
                // int: vocabnum
                // long: memblocksize
                // long: aliasMemBlockSize
                // (serializing term by term, for one term)
                // int: term_id, int: topic_num, KeyValuePair<int, int>[]: termTopicVector

                InfoEx = new ColInfoEx(ectx, ctx);

                _numVocab = ctx.Reader.ReadInt32();
                ectx.CheckDecode(_numVocab > 0);

                long modelBlockSize = ctx.Reader.ReadInt64();
                ectx.CheckDecode(modelBlockSize > 0);

                long aliasBlockSize = ctx.Reader.ReadInt64();
                ectx.CheckDecode(aliasBlockSize > 0);

                _ldaTrainer = new LdaSingleBox(
                    InfoEx.NumTopic,
                    _numVocab, /* Need to set number of vocabulary here */
                    InfoEx.AlphaSum,
                    InfoEx.Beta,
                    InfoEx.NumIter,
                    InfoEx.LikelihoodInterval,
                    InfoEx.NumThread,
                    InfoEx.MHStep,
                    InfoEx.NumSummaryTermPerTopic,
                    false,
                    InfoEx.NumMaxDocToken);

                _ldaTrainer.AllocateModelMemory(_numVocab, InfoEx.NumTopic, modelBlockSize, aliasBlockSize);

                // One record per vocabulary entry: term id, pair count, then that many pairs of ints.
                for (int term = 0; term < _numVocab; term++)
                {
                    int id = ctx.Reader.ReadInt32();
                    ectx.CheckDecode(id >= 0);

                    int pairCount = ctx.Reader.ReadInt32();
                    ectx.CheckDecode(pairCount >= 0);

                    var ids   = new int[pairCount];
                    var probs = new int[pairCount];

                    // The two values of each pair are stored next to each other in the stream.
                    for (int k = 0; k < pairCount; k++)
                    {
                        ids[k]   = ctx.Reader.ReadInt32();
                        probs[k] = ctx.Reader.ReadInt32();
                    }

                    // Set the topics of this term into the trainer's inner topic table.
                    _ldaTrainer.SetModel(id, ids, probs, pairCount);
                }

                // Do the preparation for prediction, once.
                if (!_predictionPreparationDone)
                {
                    _ldaTrainer.InitializeBeforeTest();
                    _predictionPreparationDone = true;
                }
            }