/// <summary>
/// Appends the ONNX nodes that derive the boolean PredictedLabel column from the classifier output:
/// a Binarizer over Probability (when that column exists) or over Score, followed by a Cast to boolean.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the new nodes.</param>
private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            // Without a Features column in the graph there is nothing for this scorer to consume.
            bool featuresExported = ctx.ContainsColumn(DefaultColumnNames.Features);
            if (!featuresExported)
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);

            int derivedColumnCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedColumnCount == 1);

            // Output column names in binding order: PredictedLabel, Score, and Probability (if present).
            int columnCount = Bindings.InfoCount;
            var columnNames = new string[columnCount];
            for (int i = 0; i < columnCount; i++)
            {
                int schemaColumn = Bindings.MapIinfoToCol(i);
                columnNames[i] = Bindings.GetColumnName(schemaColumn);
            }

            const string binarizerOpType = "Binarizer";
            var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);

            // When a Probability column was generated the label is thresholded at 0.5;
            // otherwise it follows the sign of the Score (threshold 0.0).
            bool hasProbability = columnCount >= 3;
            if (hasProbability)
            {
                Host.Assert(ctx.ContainsColumn(columnNames[2]));
            }

            string binarizerSourceColumn = hasProbability ? columnNames[2] : columnNames[1];
            double threshold = hasProbability ? 0.5 : 0.0;

            OnnxNode binarizer = ctx.CreateNode(binarizerOpType, ctx.GetVariableName(binarizerSourceColumn), binarizerOutput, ctx.GetNodeName(binarizerOpType));
            binarizer.AddAttribute("threshold", threshold);

            // Convert the binarizer result into the boolean PredictedLabel column.
            const string castOpType = "Cast";
            OnnxNode cast = ctx.CreateNode(castOpType, binarizerOutput, ctx.GetVariableName(columnNames[0]), ctx.GetNodeName(castOpType), "");
            var booleanType = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
            cast.AddAttribute("to", booleanType);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds a <see cref="DatabaseLoader"/> with one column per public instance field and per public,
        /// readable, non-indexed instance property of <typeparamref name="TInput"/>.
        /// </summary>
        /// <typeparam name="TInput">The user type whose members describe the columns to load.</typeparam>
        /// <param name="host">Environment used to report errors and to construct the loader.</param>
        /// <returns>A loader whose options contain the columns derived from <typeparamref name="TInput"/>.</returns>
        /// <remarks>
        /// A column's name comes from <see cref="ColumnNameAttribute"/> when present, otherwise from the member name.
        /// Its source comes from <see cref="LoadColumnAttribute"/> or <see cref="LoadColumnNameAttribute"/>. Specifying
        /// both is an error; specifying neither leaves the source unset. Array members are typed by their element type.
        /// An exception is thrown when the type has no usable members or a member's type is unsupported.
        /// </remarks>
        internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment host)
        {
            var userType = typeof(TInput);

            var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);

            // Only properties with a public getter and no index parameters can be read per row.
            var propertyInfos = userType
                .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                .Where(x => x.CanRead && x.GetGetMethod() != null && x.GetIndexParameters().Length == 0);

            MemberInfo[] memberInfos = fieldInfos.Concat<MemberInfo>(propertyInfos).ToArray();

            if (memberInfos.Length == 0)
            {
                // Name the actual user type; nameof(TInput) would always render the literal "TInput".
                throw host.ExceptParam(nameof(TInput), $"Should define at least one public, readable field or property in {userType.Name}.");
            }

            var columns = new List<Column>(memberInfos.Length);
            foreach (var memberInfo in memberInfos)
            {
                columns.Add(CreateColumnFromMember(memberInfo));
            }

            var options = new Options
            {
                Columns = columns.ToArray()
            };

            return new DatabaseLoader(host, options);
        }

        /// <summary>
        /// Maps one public field or readable property to a loader <see cref="Column"/> (name, source and type).
        /// </summary>
        private static Column CreateColumnFromMember(MemberInfo memberInfo)
        {
            var mappingAttrName = memberInfo.GetCustomAttribute<ColumnNameAttribute>();

            var column = new Column();
            column.Name = mappingAttrName?.Name ?? memberInfo.Name;

            var indexMappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
            var nameMappingAttr = memberInfo.GetCustomAttribute<LoadColumnNameAttribute>();

            if (indexMappingAttr is object)
            {
                // Index-based and name-based source mappings are mutually exclusive.
                if (nameMappingAttr is object)
                {
                    throw Contracts.Except($"Cannot specify both {nameof(LoadColumnAttribute)} and {nameof(LoadColumnNameAttribute)}");
                }

                column.Source = indexMappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
            }
            else if (nameMappingAttr is object)
            {
                column.Source = nameMappingAttr.Sources.Select((source) => new Range(source)).ToArray();
            }

            column.Type = GetMemberDataKind(memberInfo).ToDbType();
            return column;
        }

        /// <summary>
        /// Resolves the <see cref="InternalDataKind"/> of a field or property; array members use their element type.
        /// Throws when the member's (element) type is unsupported or the member is neither a field nor a property.
        /// </summary>
        private static InternalDataKind GetMemberDataKind(MemberInfo memberInfo)
        {
            Type memberType;
            string memberKind;
            switch (memberInfo)
            {
                case FieldInfo field:
                    memberType = field.FieldType;
                    memberKind = "Field";
                    break;

                case PropertyInfo property:
                    memberType = property.PropertyType;
                    memberKind = "Property";
                    break;

                default:
                    Contracts.Assert(false);
                    throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
            }

            var itemType = memberType.IsArray ? memberType.GetElementType() : memberType;
            if (!InternalDataKindExtensions.TryGetDataKind(itemType, out InternalDataKind dk))
            {
                throw Contracts.Except($"{memberKind} {memberInfo.Name} is of unsupported type.");
            }

            return dk;
        }