private void CompressionWorker(BlockingCollection <Block> toCompress, BlockingCollection <Block> toWrite, int columns, OrderedWaiter waiter,
                                       ExceptionMarshaller exMarshaller)
        {
            Contracts.AssertValue(exMarshaller);
            try
            {
                _host.AssertValue(toCompress);
                _host.AssertValue(toWrite);
                _host.Assert(columns > 0);
                _host.Assert(_deterministicBlockOrder == (waiter != null));

                foreach (Block block in toCompress.GetConsumingEnumerable(exMarshaller.Token))
                {
                    MemoryStream compressed = _memPool.Get();
                    int          uncompLength;
                    using (Stream stream = _compression.CompressStream(compressed))
                    {
                        MemoryStream uncompressed = block.BlockData;
                        uncompLength = (int)uncompressed.Length;
                        ArraySegment <byte> buffer;
                        bool tmp = uncompressed.TryGetBuffer(out buffer);
                        Contracts.Assert(tmp);
                        stream.Write(buffer.Array, buffer.Offset, buffer.Count);
                        _memPool.Return(ref uncompressed);
                    }
                    if (_deterministicBlockOrder)
                    {
                        waiter.Wait((long)columns * block.BlockIndex + block.ColumnIndex, exMarshaller.Token);
                    }
                    toWrite.Add(new Block(compressed, block.ColumnIndex, block.BlockIndex, uncompLength), exMarshaller.Token);
                    if (_deterministicBlockOrder)
                    {
                        waiter.Increment();
                    }
                }
            }
            catch (Exception ex)
            {
                exMarshaller.Set("compressing", ex);
            }
        }
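
The worker above is one stage of a producer/consumer pipeline: it drains toCompress, compresses each block into a pooled MemoryStream, and hands the result to toWrite, using the OrderedWaiter only when deterministic block order is requested. Below is a minimal self-contained sketch of the same BlockingCollection pattern; the class name, the GZip codec, and the block sizes are illustrative, not taken from the snippet.

using System;
using System.Collections.Concurrent;
using System.IO;
using System.IO.Compression;
using System.Threading.Tasks;

class CompressionPipelineSketch
{
    static void Main()
    {
        var toCompress = new BlockingCollection<byte[]>(boundedCapacity: 8);
        var toWrite    = new BlockingCollection<byte[]>(boundedCapacity: 8);

        // Compression stage: drain one queue, feed the next.
        var compressor = Task.Run(() =>
        {
            foreach (var block in toCompress.GetConsumingEnumerable())
            {
                using (var output = new MemoryStream())
                {
                    using (var gzip = new GZipStream(output, CompressionLevel.Fastest, leaveOpen: true))
                        gzip.Write(block, 0, block.Length);
                    toWrite.Add(output.ToArray());
                }
            }
            toWrite.CompleteAdding(); // propagate end-of-stream downstream
        });

        // Writer stage.
        var writer = Task.Run(() =>
        {
            foreach (var compressed in toWrite.GetConsumingEnumerable())
                Console.WriteLine($"compressed block: {compressed.Length} bytes");
        });

        for (int i = 0; i < 4; i++)
            toCompress.Add(new byte[1024]); // producer
        toCompress.CompleteAdding();

        Task.WaitAll(compressor, writer);
    }
}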
Example #2
        // This method is called if only a datafile is specified, without a loader/term and value columns.
        // It determines the type of the Value column and returns the appropriate TextLoader component factory.
        private static IComponentFactory <IMultiStreamSource, IDataLoader> GetLoaderFactory(string filename, bool keyValues, IHost host)
        {
            Contracts.AssertValue(host);

            // If the user specified non-key values, we define the value column to be numeric.
            if (!keyValues)
            {
                return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                           (env, files) => new TextLoader(
                               env,
                               new TextLoader.Arguments()
                {
                    Column = new[]
                    {
                        new TextLoader.Column("Term", DataKind.TX, 0),
                        new TextLoader.Column("Value", DataKind.Num, 1)
                    }
                },
                               files)));
            }

            // If the user specified key values, we scan the values to determine the range of the key type.
            ulong min = ulong.MaxValue;
            ulong max = ulong.MinValue;

            try
            {
                var  txtArgs = new TextLoader.Arguments();
                bool parsed  = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
                host.Assert(parsed);
                var txtLoader = new TextLoader(host, txtArgs, new MultiFileSource(filename));
                using (var cursor = txtLoader.GetRowCursor(c => true))
                {
                    var    getTerm = cursor.GetGetter <DvText>(0);
                    var    getVal  = cursor.GetGetter <DvText>(1);
                    DvText txt     = default(DvText);

                    using (var ch = host.Start("Creating Text Lookup Loader"))
                    {
                        long countNonKeys = 0;
                        while (cursor.MoveNext())
                        {
                            getVal(ref txt);
                            ulong res;
                            // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0,
                            // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for
                            // computing max and min.
                            if (Conversions.Instance.TryParseKey(ref txt, 1, ulong.MaxValue, out res))
                            {
                                if (res < min && res != 0)
                                {
                                    min = res;
                                }
                                if (res > max)
                                {
                                    max = res;
                                }
                            }
                            // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds,
                            // then the value is 0, and we update min accordingly.
                            else if (Conversions.Instance.TryParse(ref txt, out res))
                            {
                                ch.Assert(res == 0);
                                min = 0;
                            }
                            // If parsing as a ulong fails, we increment the counter for the non-key values.
                            else
                            {
                                var term = default(DvText);
                                getTerm(ref term);
                                if (countNonKeys < 5)
                                {
                                    ch.Warning("Term '{0}' in mapping file is mapped to non key value '{1}'", term, txt);
                                }
                                countNonKeys++;
                            }
                        }
                        if (countNonKeys > 0)
                        {
                            ch.Warning("Found {0} non key values in the file '{1}'", countNonKeys, filename);
                        }
                        if (min > max)
                        {
                            min = 0;
                            max = uint.MaxValue - 1;
                            ch.Warning("did not find any valid key values in the file '{0}'", filename);
                        }
                        else
                        {
                            ch.Info("Found key values in the range {0} to {1} in the file '{2}'", min, max, filename);
                        }
                        ch.Done();
                    }
                }
            }
            catch (Exception e)
            {
                throw host.Except(e, "Failed to parse the lookup file '{0}' in TermLookupTransform", filename);
            }

            TextLoader.Column valueColumn = new TextLoader.Column("Value", DataKind.U4, 1);
            if (max - min < (ulong)int.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min, max);
            }
            else if (max - min < (ulong)uint.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min);
            }
            else
            {
                valueColumn.Type     = DataKind.U8;
                valueColumn.KeyRange = new KeyRange(min);
            }

            return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                       (env, files) => new TextLoader(
                           env,
                           new TextLoader.Arguments()
            {
                Column = new[]
                {
                    new TextLoader.Column("Term", DataKind.TX, 0),
                    valueColumn
                }
            },
                           files)));
        }
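
The tail of GetLoaderFactory picks the narrowest key type that covers the observed range: a bounded U4 key when max - min fits in an int, an open-ended U4 key when it fits in a uint, and U8 otherwise. A standalone restatement of that rule follows; the tuple and the "U4"/"U8" strings stand in for KeyRange and DataKind, an assumption made here for self-containment.

var (kind, lo, hi) = ChooseKeyType(1, 100); // -> ("U4", 1, 100)

static (string kind, ulong min, ulong? max) ChooseKeyType(ulong min, ulong max)
{
    if (max - min < (ulong)int.MaxValue)
        return ("U4", min, max);      // bounded 4-byte key range
    if (max - min < (ulong)uint.MaxValue)
        return ("U4", min, null);     // open-ended 4-byte key range
    return ("U8", min, null);         // fall back to 8-byte keys
}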
Example #3
        public IPredictor CombineModels(IEnumerable <IPredictor> models)
        {
            _host.CheckValue(models, nameof(models));

            var  ensemble         = new InternalTreeEnsemble();
            int  modelCount       = 0;
            int  featureCount     = -1;
            bool binaryClassifier = false;

            foreach (var model in models)
            {
                modelCount++;

                var predictor = model;
                _host.CheckValue(predictor, nameof(models), "One of the models is null");

                var    calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
                double paramA     = 1;
                if (calibrated != null)
                {
                    _host.Check(calibrated.WeeklyTypedCalibrator is PlattCalibrator,
                                "Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
                    // Only unwrap the sub-model when a calibrator is actually present;
                    // dereferencing 'calibrated' unconditionally would throw for uncalibrated models.
                    predictor = calibrated.WeeklyTypedSubModel;
                    paramA    = -((PlattCalibrator)calibrated.WeeklyTypedCalibrator).Slope;
                }

                var tree = predictor as TreeEnsembleModelParameters;

                if (tree == null)
                {
                    throw _host.Except("Model is not a tree ensemble");
                }
                foreach (var t in tree.TrainedEnsemble.Trees)
                {
                    var bytes    = new byte[t.SizeInBytes()];
                    int position = -1;
                    t.ToByteArray(bytes, ref position);
                    position = -1;
                    var tNew = new InternalRegressionTree(bytes, ref position);
                    if (paramA != 1)
                    {
                        for (int i = 0; i < tNew.NumLeaves; i++)
                        {
                            tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
                        }
                    }
                    ensemble.AddTree(tNew);
                }

                if (modelCount == 1)
                {
                    binaryClassifier = calibrated != null;
                    featureCount     = tree.InputType.GetValueCount();
                }
                else
                {
                    _host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
                    _host.Check(featureCount == tree.InputType.GetValueCount(), "Found models with different number of features");
                }
            }

            var scale = 1 / (double)modelCount;

            foreach (var t in ensemble.Trees)
            {
                for (int i = 0; i < t.NumLeaves; i++)
                {
                    t.SetOutput(i, t.LeafValues[i] * scale);
                }
            }

            switch (_kind)
            {
            case PredictionKind.BinaryClassification:
                if (!binaryClassifier)
                {
                    return(new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null));
                }

                var cali          = new PlattCalibrator(_host, -1, 0);
                var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
                return(new FeatureWeightsCalibratedModelParameters <FastTreeBinaryModelParameters, PlattCalibrator>(_host, fastTreeModel, cali));

            case PredictionKind.Regression:
                return(new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null));

            case PredictionKind.Ranking:
                return(new FastTreeRankingModelParameters(_host, ensemble, featureCount, null));

            default:
                _host.Assert(false);
                throw _host.ExceptNotSupp();
            }
        }
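
CombineModels pools every tree from every member model (folding each Platt calibrator's slope into the leaf values first) and then rescales all leaves by 1/modelCount, so the combined ensemble's raw score is the arithmetic mean of the member models' scores. A toy illustration of the rescaling step, with made-up leaf values:

// Two single-tree "models"; scaling every leaf by 1/modelCount makes the
// pooled ensemble predict the mean of the member predictions.
double[][] pooledLeaves =
{
    new[] { 1.0, -2.0 }, // leaves from model 1's tree
    new[] { 0.5,  3.0 }, // leaves from model 2's tree
};
int modelCount = 2;
double scale = 1.0 / modelCount;
foreach (var leaves in pooledLeaves)
    for (int i = 0; i < leaves.Length; i++)
        leaves[i] *= scale;
// An example falling in leaf 0 of both trees now scores 0.5 + 0.25 = 0.75,
// the mean of the original per-model scores 1.0 and 0.5.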
Example #4
 public int GetInputIndex(int outputIndex)
 {
     _host.Assert(0 <= outputIndex && outputIndex < OutputToInputMap.Length);
     return(OutputToInputMap[outputIndex]);
 }
Example #5
 protected virtual void FinishPassCore()
 {
     Host.Assert(PassNum < 1);
 }
Example #6
            //private Delegate CreateGetter(SchemaProxy schema, int index, Delegate peek)
            private Delegate CreateGetter(ColumnType colType, InternalSchemaDefinition.Column column, Delegate peek)
            {
                var outputType  = column.OutputType;
                var genericType = outputType;
                Func <Delegate, Delegate> del;

                if (outputType.IsArray)
                {
                    Host.Assert(colType.IsVector);
                    // String[] -> ReadOnlyMemory<char>
                    if (outputType.GetElementType() == typeof(string))
                    {
                        Host.Assert(colType.ItemType.IsText);
                        return(CreateConvertingArrayGetterDelegate <string, ReadOnlyMemory <char> >(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory <char> .Empty));
                    }

                    // T[] -> VBuffer<T>
                    if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable <>))
                    {
                        Host.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == colType.ItemType.RawType);
                    }
                    else
                    {
                        Host.Assert(outputType.GetElementType() == colType.ItemType.RawType);
                    }
                    del         = CreateDirectArrayGetterDelegate <int>;
                    genericType = outputType.GetElementType();
                }
                else if (colType.IsVector)
                {
                    // VBuffer<T> -> VBuffer<T>
                    // REVIEW: Do we care about accommodating VBuffer<string> -> ReadOnlyMemory<char>?
                    Host.Assert(outputType.IsGenericType);
                    Host.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer <>));
                    Host.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType);
                    del         = CreateDirectVBufferGetterDelegate <int>;
                    genericType = colType.ItemType.RawType;
                }
                else if (colType.IsPrimitive)
                {
                    if (outputType == typeof(string))
                    {
                        // String -> ReadOnlyMemory<char>
                        Host.Assert(colType.IsText);
                        return(CreateConvertingGetterDelegate <String, ReadOnlyMemory <char> >(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory <char> .Empty));
                    }

                    // T -> T
                    if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable <>))
                    {
                        Host.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
                    }
                    else
                    {
                        Host.Assert(colType.RawType == outputType);
                    }

                    if (!colType.IsKey)
                    {
                        del = CreateDirectGetterDelegate <int>;
                    }
                    else
                    {
                        var keyRawType = colType.RawType;
                        Host.Assert(colType.AsKey.Contiguous);
                        Func <Delegate, ColumnType, Delegate> delForKey = CreateKeyGetterDelegate <uint>;
                        return(Utils.MarshalInvoke(delForKey, keyRawType, peek, colType));
                    }
                }
                else
                {
                    // REVIEW: Is this even possible?
                    throw Host.ExceptNotSupp("Type '{0}' is not yet supported.", outputType.FullName);
                }
                return(Utils.MarshalInvoke(del, genericType, peek));
            }
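
Utils.MarshalInvoke is what turns the <int>-bound delegates above into calls instantiated at the runtime item type: it recovers the generic method definition behind the delegate and re-makes it at the requested Type. A hedged sketch of that mechanism, assuming a reflection-based approach; it is not ML.NET's actual implementation.

using System;
using System.Reflection;

static class MarshalInvokeSketch
{
    // Re-instantiate the generic method behind 'del' at 'genericType' and call it.
    // 'del' must be bound to some instantiation (e.g. Method<int>) of a generic method.
    public static Delegate Invoke(Func<Delegate, Delegate> del, Type genericType, Delegate arg)
    {
        MethodInfo method = del.Method.GetGenericMethodDefinition().MakeGenericMethod(genericType);
        return (Delegate)method.Invoke(del.Target, new object[] { arg });
    }
}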
Example #7
            //private Delegate CreateGetter(SchemaProxy schema, int index, Delegate peek)
            private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Column column, Delegate peek)
            {
                var outputType  = column.OutputType;
                var genericType = outputType;
                Func <Delegate, Delegate> del;

                if (outputType.IsArray)
                {
                    VectorType vectorType = colType as VectorType;
                    Host.Assert(vectorType != null);

                    // String[] -> ReadOnlyMemory<char>
                    if (outputType.GetElementType() == typeof(string))
                    {
                        Host.Assert(vectorType.ItemType is TextDataViewType);
                        return(CreateConvertingArrayGetterDelegate <string, ReadOnlyMemory <char> >(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory <char> .Empty));
                    }

                    // T[] -> VBuffer<T>
                    if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable <>))
                    {
                        Host.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == vectorType.ItemType.RawType);
                    }
                    else
                    {
                        Host.Assert(outputType.GetElementType() == vectorType.ItemType.RawType);
                    }
                    del         = CreateDirectArrayGetterDelegate <int>;
                    genericType = outputType.GetElementType();
                }
                else if (colType is VectorType vectorType)
                {
                    // VBuffer<T> -> VBuffer<T>
                    // REVIEW: Do we care about accommodating VBuffer<string> -> ReadOnlyMemory<char>?
                    Host.Assert(outputType.IsGenericType);
                    Host.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer <>));
                    Host.Assert(outputType.GetGenericArguments()[0] == vectorType.ItemType.RawType);
                    del         = CreateDirectVBufferGetterDelegate <int>;
                    genericType = vectorType.ItemType.RawType;
                }
                else if (colType is PrimitiveDataViewType)
                {
                    if (outputType == typeof(string))
                    {
                        // String -> ReadOnlyMemory<char>
                        Host.Assert(colType is TextDataViewType);
                        return(CreateConvertingGetterDelegate <String, ReadOnlyMemory <char> >(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory <char> .Empty));
                    }

                    // T -> T
                    if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable <>))
                    {
                        Host.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
                    }
                    else
                    {
                        Host.Assert(colType.RawType == outputType);
                    }

                    if (!(colType is KeyType keyType))
                    {
                        del = CreateDirectGetterDelegate <int>;
                    }
                    else
                    {
                        var keyRawType = colType.RawType;
                        Func <Delegate, DataViewType, Delegate> delForKey = CreateKeyGetterDelegate <uint>;
                        return(Utils.MarshalInvoke(delForKey, keyRawType, peek, colType));
                    }
                }
                else
                {
                    // REVIEW: Is this even possible?
                    throw Host.ExceptNotSupp("Type '{0}' is not yet supported.", outputType.FullName);
                }
                return(Utils.MarshalInvoke(del, genericType, peek));
            }
Example #8
        private static Float[] Train(IHost host, ColInfo[] infos, Arguments args, IDataView trainingData)
        {
            Contracts.AssertValue(host, "host");
            host.AssertNonEmpty(infos);

            var       avgDistances  = new Float[infos.Length];
            const int reservoirSize = 5000;

            bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
            for (int i = 0; i < infos.Length; i++)
            {
                activeColumns[infos[i].Source] = true;
            }

            var reservoirSamplers = new ReservoirSamplerWithReplacement <VBuffer <Float> > [infos.Length];

            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                var rng = args.Seed.HasValue ? RandomUtils.Create(args.Seed) : host.Rand;
                for (int i = 0; i < infos.Length; i++)
                {
                    if (infos[i].TypeSrc.IsVector)
                    {
                        var get = cursor.GetGetter <VBuffer <Float> >(infos[i].Source);
                        reservoirSamplers[i] = new ReservoirSamplerWithReplacement <VBuffer <Float> >(rng, reservoirSize, get);
                    }
                    else
                    {
                        var   getOne = cursor.GetGetter <Float>(infos[i].Source);
                        Float val    = 0;
                        ValueGetter <VBuffer <Float> > get =
                            (ref VBuffer <Float> dst) =>
                        {
                            getOne(ref val);
                            dst = new VBuffer <float>(1, new[] { val });
                        };
                        reservoirSamplers[i] = new ReservoirSamplerWithReplacement <VBuffer <Float> >(rng, reservoirSize, get);
                    }
                }

                while (cursor.MoveNext())
                {
                    for (int i = 0; i < infos.Length; i++)
                    {
                        reservoirSamplers[i].Sample();
                    }
                }
                for (int i = 0; i < infos.Length; i++)
                {
                    reservoirSamplers[i].Lock();
                }
            }

            for (int iinfo = 0; iinfo < infos.Length; iinfo++)
            {
                var instanceCount = reservoirSamplers[iinfo].NumSampled;

                // If the number of pairs is at most the maximum reservoir size / 2, we go over all the pairs,
                // so we get all the examples. Otherwise, get a sample with replacement.
                VBuffer <Float>[] res;
                int resLength;
                if (instanceCount < reservoirSize && instanceCount * (instanceCount - 1) <= reservoirSize)
                {
                    res       = reservoirSamplers[iinfo].GetCache();
                    resLength = reservoirSamplers[iinfo].Size;
                    Contracts.Assert(resLength == instanceCount);
                }
                else
                {
                    res       = reservoirSamplers[iinfo].GetSample().ToArray();
                    resLength = res.Length;
                }

                // If the dataset contains only one valid Instance, then we can't learn anything anyway, so just return 1.
                if (instanceCount <= 1)
                {
                    avgDistances[iinfo] = 1;
                }
                else
                {
                    Float[] distances;
                    var     sub = args.Column[iinfo].MatrixGenerator;
                    if (sub == null)
                    {
                        sub = args.MatrixGenerator;
                    }
                    // create a dummy generator in order to get its type.
                    // REVIEW this should be refactored. See https://github.com/dotnet/machinelearning/issues/699
                    var  matrixGenerator = sub.CreateComponent(host, 1);
                    bool gaussian        = matrixGenerator is GaussianFourierSampler;

                    // If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
                    if (resLength < reservoirSize)
                    {
                        distances = new Float[instanceCount * (instanceCount - 1) / 2];
                        int count = 0;
                        for (int i = 0; i < instanceCount; i++)
                        {
                            for (int j = i + 1; j < instanceCount; j++)
                            {
                                distances[count++] = gaussian ? VectorUtils.L2DistSquared(ref res[i], ref res[j])
                                    : VectorUtils.L1Distance(ref res[i], ref res[j]);
                            }
                        }
                        host.Assert(count == distances.Length);
                    }
                    else
                    {
                        distances = new Float[reservoirSize / 2];
                        for (int i = 0; i < reservoirSize - 1; i += 2)
                        {
                            // For Gaussian kernels, we scale by the L2 distance squared, since the kernel function is exp(-gamma ||x-y||^2).
                            // For Laplacian kernels, we scale by the L1 distance, since the kernel function is exp(-gamma ||x-y||_1).
                            distances[i / 2] = gaussian ? VectorUtils.L2DistSquared(ref res[i], ref res[i + 1]) :
                                               VectorUtils.L1Distance(ref res[i], ref res[i + 1]);
                        }
                    }

                    // If, by chance, every sampled pair compares identical instances, the median distance is 0, so return 1.
                    Float median = MathUtils.GetMedianInPlace(distances, distances.Length);
                    avgDistances[iinfo] = median == 0 ? 1 : median;
                }
            }
            return(avgDistances);
        }
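
Train's bandwidth heuristic is the median pairwise distance over a reservoir sample: squared L2 for Gaussian kernels, since the kernel is exp(-gamma ||x-y||^2), and L1 for Laplacian. The sketch below restates the small-sample path (all pairs) over dense double vectors; the class and method names are illustrative.

using System;
using System.Collections.Generic;

static class MedianDistanceSketch
{
    public static double MedianPairwiseDistance(double[][] sample, bool gaussian)
    {
        if (sample.Length <= 1)
            return 1; // a single instance: nothing to learn, mirroring the code above

        var dists = new List<double>();
        for (int i = 0; i < sample.Length; i++)
        {
            for (int j = i + 1; j < sample.Length; j++)
            {
                double d = 0;
                for (int k = 0; k < sample[i].Length; k++)
                {
                    double diff = sample[i][k] - sample[j][k];
                    d += gaussian ? diff * diff : Math.Abs(diff); // L2 squared vs. L1
                }
                dists.Add(d);
            }
        }
        dists.Sort();
        double median = dists[dists.Count / 2];
        return median == 0 ? 1 : median; // all pairs identical: fall back to 1
    }
}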
Example #9
        private FieldAwareFactorizationMachineModelParameters TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data,
                                                                        RoleMappedData validData = null, FieldAwareFactorizationMachineModelParameters predictor = null)
        {
            _host.AssertValue(ch);
            _host.AssertValue(pch);

            data.CheckBinaryLabel();
            var featureColumns    = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
            int fieldCount        = featureColumns.Count;
            int totalFeatureCount = 0;

            int[] fieldColumnIndexes = new int[fieldCount];
            for (int f = 0; f < fieldCount; f++)
            {
                var col = featureColumns[f];
                _host.Assert(!col.IsHidden);
                if (!(col.Type is VectorDataViewType vectorType) ||
                    !vectorType.IsKnownSize ||
                    vectorType.ItemType != NumberDataViewType.Single)
                {
                    throw ch.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of Single, but has type: {1}.", col.Name, col.Type);
                }
                _host.Assert(vectorType.Size > 0);
                fieldColumnIndexes[f] = col.Index;
                totalFeatureCount    += vectorType.Size;
            }
            ch.Check(checked (totalFeatureCount * fieldCount * _latentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension or the number of fields too large");
            if (predictor != null)
            {
                ch.Check(predictor.FeatureCount == totalFeatureCount, "Input model's feature count mismatches training feature count");
                ch.Check(predictor.LatentDimension == _latentDim, "Input model's latent dimension mismatches trainer's");
            }
            if (validData != null)
            {
                validData.CheckBinaryLabel();
                var validFeatureColumns = validData.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
                _host.Assert(fieldCount == validFeatureColumns.Count);
                for (int f = 0; f < fieldCount; f++)
                {
                    var featCol      = featureColumns[f];
                    var validFeatCol = validFeatureColumns[f];
                    _host.Assert(featCol.Name == validFeatCol.Name);
                    _host.Assert(featCol.Type == validFeatCol.Type);
                }
            }
            bool shuffle = _shuffle;

            if (shuffle && !data.Data.CanShuffle)
            {
                ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
                shuffle = false;
            }
            var rng                = shuffle ? _host.Rand : null;
            var featureGetters     = new ValueGetter <VBuffer <float> > [fieldCount];
            var featureBuffer      = new VBuffer <float>();
            var featureValueBuffer = new float[totalFeatureCount];
            var featureIndexBuffer = new int[totalFeatureCount];
            var featureFieldBuffer = new int[totalFeatureCount];
            var latentSum          = new AlignedArray(fieldCount * fieldCount * _latentDimAligned, 16);
            var metricNames        = new List <string>()
            {
                "Training-loss"
            };

            if (validData != null)
            {
                metricNames.Add("Validation-loss");
            }
            int    iter                 = 0;
            long   exampleCount         = 0;
            long   badExampleCount      = 0;
            long   validBadExampleCount = 0;
            double loss                 = 0;
            double validLoss            = 0;

            pch.SetHeader(new ProgressHeader(metricNames.ToArray(), new string[] { "iterations", "examples" }), entry =>
            {
                entry.SetProgress(0, iter, _numIterations);
                entry.SetProgress(1, exampleCount);
            });

            var columns = data.Schema.Schema.Where(x => fieldColumnIndexes.Contains(x.Index)).ToList();

            columns.Add(data.Schema.Label.Value);
            if (data.Schema.Weight != null)
            {
                columns.Add(data.Schema.Weight.Value);
            }

            InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights,
                                    out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned);

            // refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
            while (iter++ < _numIterations)
            {
                using (var cursor = data.Data.GetRowCursor(columns, rng))
                {
                    var labelGetter  = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index);
                    var weightGetter = data.Schema.Weight?.Index is int weightIdx ? RowCursorUtils.GetGetterAs <float>(NumberDataViewType.Single, cursor, weightIdx) : null;

                    for (int i = 0; i < fieldCount; i++)
                    {
                        featureGetters[i] = cursor.GetGetter <VBuffer <float> >(cursor.Schema[fieldColumnIndexes[i]]);
                    }
                    loss            = 0;
                    exampleCount    = 0;
                    badExampleCount = 0;
                    while (cursor.MoveNext())
                    {
                        float label         = 0;
                        float weight        = 1;
                        int   count         = 0;
                        float modelResponse = 0;
                        labelGetter(ref label);
                        weightGetter?.Invoke(ref weight);
                        float annihilation = label - label + weight - weight;
                        if (!FloatUtils.IsFinite(annihilation))
                        {
                            badExampleCount++;
                            continue;
                        }
                        if (!FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(featureGetters, featureBuffer, _norm, ref count,
                                                                                          featureFieldBuffer, featureIndexBuffer, featureValueBuffer))
                        {
                            badExampleCount++;
                            continue;
                        }

                        // refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                        FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count,
                                                                                               featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
                        var slope = CalculateLossSlope(label, modelResponse);

                        // refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                        FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count,
                                                                                           featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned);
                        loss += weight * CalculateLoss(label, modelResponse);
                        exampleCount++;
                    }
                    loss /= exampleCount;
                }

                if (_verbose)
                {
                    if (validData == null)
                    {
                        pch.Checkpoint(loss, iter, exampleCount);
                    }
                    else
                    {
                        validLoss = CalculateAvgLoss(ch, validData, _norm, linearWeights, latentWeightsAligned, _latentDimAligned, latentSum,
                                                     featureFieldBuffer, featureIndexBuffer, featureValueBuffer, featureBuffer, ref validBadExampleCount);
                        pch.Checkpoint(loss, validLoss, iter, exampleCount);
                    }
                }
            }
            if (badExampleCount != 0)
            {
                ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set");
            }
            if (validBadExampleCount != 0)
            {
                ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
            }

            return(new FieldAwareFactorizationMachineModelParameters(_host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned));
        }
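
The annihilation line in the cursoring loop is a compact bad-value filter: for finite inputs x - x is exactly 0, while NaN and infinity propagate through the arithmetic, so the sum is non-finite precisely when the label or the weight is. A standalone illustration:

using System;

foreach (var (label, weight) in new[] { (1f, 2f), (float.NaN, 1f), (1f, float.PositiveInfinity) })
{
    float annihilation = label - label + weight - weight;
    Console.WriteLine($"label={label} weight={weight} finite={float.IsFinite(annihilation)}");
}
// Only the first pair is finite; the other two examples would be skipped.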
Example #10
        public void Run()
        {
            string template;

            using (var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(CodeTemplatePath))
                using (var reader = new StreamReader(stream))
                    template = reader.ReadToEnd();

            var codeProvider = new CSharpCodeProvider();

            using (var fs = File.OpenRead(_args.InputModelFile))
            {
                var transformPipe = ModelFileUtils.LoadPipeline(_host, fs, new MultiFileSource(null), true);
                var pred          = _host.LoadPredictorOrNull(fs);

                IDataView root;
                for (root = transformPipe; root is IDataTransform; root = ((IDataTransform)root).Source)
                {
                    ;
                }

                // root is now the loader.
                _host.Assert(root is IDataLoader);

                // Loader columns.
                var loaderSb = new StringBuilder();
                for (int i = 0; i < root.Schema.ColumnCount; i++)
                {
                    if (root.Schema.IsHidden(i))
                    {
                        continue;
                    }
                    if (loaderSb.Length > 0)
                    {
                        loaderSb.AppendLine();
                    }

                    ColumnType colType = root.Schema.GetColumnType(i);
                    CodeGenerationUtils.AppendFieldDeclaration(codeProvider, loaderSb, i, root.Schema.GetColumnName(i), colType, true, _args.SparseVectorDeclaration);
                }

                // Scored example columns.
                IDataView scorer;
                if (pred == null)
                {
                    scorer = transformPipe;
                }
                else
                {
                    var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs);
                    scorer = roles != null
                        ? _host.CreateDefaultScorer(new RoleMappedData(transformPipe, roles, opt: true), pred)
                        : _host.CreateDefaultScorer(new RoleMappedData(transformPipe, label: null, "Features"), pred);
                }

                var nonScoreSb = new StringBuilder();
                var scoreSb    = new StringBuilder();
                for (int i = 0; i < scorer.Schema.ColumnCount; i++)
                {
                    if (scorer.Schema.IsHidden(i))
                    {
                        continue;
                    }
                    bool isScoreColumn = scorer.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnSetId, i) != null;

                    var sb = isScoreColumn ? scoreSb : nonScoreSb;

                    if (sb.Length > 0)
                    {
                        sb.AppendLine();
                    }

                    ColumnType colType = scorer.Schema.GetColumnType(i);
                    CodeGenerationUtils.AppendFieldDeclaration(codeProvider, sb, i, scorer.Schema.GetColumnName(i), colType, false, _args.SparseVectorDeclaration);
                }

                // Turn model path into a C# identifier and insert it.
                var modelPath = !string.IsNullOrWhiteSpace(_args.ModelNameOverride) ? _args.ModelNameOverride : _args.InputModelFile;
                modelPath = CodeGenerationUtils.GetCSharpString(codeProvider, modelPath);
                modelPath = string.Format("modelPath = {0};", modelPath);

                // Replace values inside the template.
                var replacementMap =
                    new Dictionary <string, string>
                {
                    { "EXAMPLE_CLASS_DECL", loaderSb.ToString() },
                    { "SCORED_EXAMPLE_CLASS_DECL", nonScoreSb.ToString() },
                    { "SCORE_CLASS_DECL", scoreSb.ToString() },
                    { "MODEL_PATH", modelPath }
                };

                var classSource = CodeGenerationUtils.MultiReplace(template, replacementMap);
                File.WriteAllText(_args.CSharpOutput, classSource);
            }
        }
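
Run is template-driven: it reads the embedded code template, builds per-column field declarations, and substitutes named placeholders such as MODEL_PATH. A plain stand-in for the substitution step follows; CodeGenerationUtils.MultiReplace's real implementation may differ (for example in escaping or single-pass behavior).

using System.Collections.Generic;

var filled = MultiReplaceSketch(
    "public static string modelPath; // MODEL_PATH",
    new Dictionary<string, string> { ["MODEL_PATH"] = "modelPath = \"model.zip\";" });

static string MultiReplaceSketch(string template, IReadOnlyDictionary<string, string> map)
{
    // Naive multi-replacement: substitute each placeholder token in turn.
    foreach (var kv in map)
        template = template.Replace(kv.Key, kv.Value);
    return template;
}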
Example #11
        // This method is called if only a datafile is specified, without a loader/term and value columns.
        // It determines the type of the Value column and returns the appropriate TextLoader component factory.
        private static IComponentFactory <IMultiStreamSource, IDataLoader> GetLoaderFactory(string filename, bool keyValues, IHost host)
        {
            Contracts.AssertValue(host);

            // If the user specified non-key values, we define the value column to be numeric.
            if (!keyValues)
            {
                return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                           (env, files) => TextLoader.Create(
                               env,
                               new TextLoader.Arguments()
                {
                    Column = new[]
                    {
                        new TextLoader.Column("Term", DataKind.TX, 0),
                        new TextLoader.Column("Value", DataKind.Num, 1)
                    }
                },
                               files)));
            }

            // If the user specified key values, we scan the values to determine the range of the key type.
            ulong min = ulong.MaxValue;
            ulong max = ulong.MinValue;

            try
            {
                var  txtArgs = new TextLoader.Arguments();
                bool parsed  = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
                host.Assert(parsed);
                var data = TextLoader.ReadFile(host, txtArgs, new MultiFileSource(filename));
                using (var cursor = data.GetRowCursor(c => true))
                {
                    var getTerm = cursor.GetGetter <ReadOnlyMemory <char> >(0);
                    var getVal  = cursor.GetGetter <ReadOnlyMemory <char> >(1);
                    ReadOnlyMemory <char> txt = default;

                    using (var ch = host.Start("Creating Text Lookup Loader"))
                    {
                        long countNonKeys = 0;
                        while (cursor.MoveNext())
                        {
                            getVal(ref txt);
                            ulong res;
                            // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0,
                            // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for
                            // computing max and min.
                            if (Conversions.Instance.TryParseKey(in txt, 1, ulong.MaxValue, out res))
                            {
                                if (res < min && res != 0)
                                {
                                    min = res;
                                }
                                if (res > max)
                                {
                                    max = res;
                                }
                            }
                            // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds,
                            // then the value is 0, and we update min accordingly.
                            else if (Conversions.Instance.TryParse(in txt, out res))
                            {
                                ch.Assert(res == 0);
                                min = 0;
                            }
                            // If parsing as a ulong fails, we increment the counter for the non-key values.
                            else
                            {
                                ReadOnlyMemory <char> term = default;
                                getTerm(ref term);
                                if (countNonKeys < 5)
                                {
                                    ch.Warning("Term '{0}' in mapping file is mapped to non key value '{1}'", term, txt);
                                }
                                countNonKeys++;
                            }
                        }
                        if (countNonKeys > 0)
                        {
                            ch.Warning("Found {0} non key values in the file '{1}'", countNonKeys, filename);
                        }
                        if (min > max)
                        {
                            min = 0;
                            max = uint.MaxValue - 1;
                            ch.Warning("did not find any valid key values in the file '{0}'", filename);
                        }
                        else
                        {
                            ch.Info("Found key values in the range {0} to {1} in the file '{2}'", min, max, filename);
                        }
                    }
                }
            }
            catch (Exception e)
            {
                throw host.Except(e, "Failed to parse the lookup file '{0}' in TermLookupTransform", filename);
            }

            TextLoader.Column valueColumn = new TextLoader.Column("Value", DataKind.U4, 1);
            if (max - min < (ulong)int.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min, max);
            }
            else if (max - min < (ulong)uint.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min);
            }
            else
            {
                valueColumn.Type     = DataKind.U8;
                valueColumn.KeyRange = new KeyRange(min);
            }

            return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                       (env, files) => TextLoader.Create(
                           env,
                           new TextLoader.Arguments()
            {
                Column = new[]
                {
                    new TextLoader.Column("Term", DataKind.TX, 0),
                    valueColumn
                }
            },
                           files)));
        }
Example #12
        public CountTableTransformer Fit(IDataView input)
        {
            var labelCol = input.Schema.GetColumnOrNull(_labelColumnName);

            if (labelCol == null)
            {
                throw _host.ExceptUserArg(nameof(_labelColumnName), "Label column '{0}' not found", _labelColumnName);
            }

            CheckLabelType(new RoleMappedData(input, roles: RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, _labelColumnName)), out var labelCardinality);

            var labelColumnType = labelCol.GetValueOrDefault().Type;
            var labelClassNames = InitLabelClassNames(_host, labelCol.GetValueOrDefault(), labelCardinality);

            var inputColumns = new DataViewSchema.Column[_columns.Length];

            for (int i = 0; i < inputColumns.Length; i++)
            {
                var col = input.Schema.GetColumnOrNull(_columns[i].InputColumnName);
                if (col == null)
                {
                    throw _host.Except($"Could not find column {_columns[i].InputColumnName} in input schema");
                }
                inputColumns[i] = col.GetValueOrDefault();
            }

            _host.Assert(_initialCounts != null || _sharedBuilder != null || _builders != null);
            MultiCountTableBuilderBase multiBuilder;

            if (_initialCounts != null)
            {
                multiBuilder = _initialCounts.Featurizer.MultiCountTable.ToBuilder(_host, inputColumns, labelCardinality);
            }
            else if (_builders != null)
            {
                multiBuilder = new ParallelMultiCountTableBuilder(_host, inputColumns, _builders, labelCardinality);
            }
            else
            {
                multiBuilder = new BagMultiCountTableBuilder(_host, inputColumns, _sharedBuilder, labelCardinality);
            }

            var cols = new List <DataViewSchema.Column>();

            foreach (var c in _columns)
            {
                var col = input.Schema.GetColumnOrNull(c.InputColumnName);
                _host.Assert(col.HasValue);
                cols.Add(col.Value);
            }

            TrainTables(input, cols, multiBuilder, labelCol.GetValueOrDefault());

            var multiCountTable = multiBuilder.CreateMultiCountTable();

            var featurizer = new CountTargetEncodingFeaturizer(_host, _columns.Select(col => col.PriorCoefficient).ToArray(), _columns.Select(col => col.LaplaceScale).ToArray(), labelCardinality, multiCountTable);

            return(new CountTableTransformer(_host, featurizer, labelClassNames,
                                             _columns.Select(col => col.Seed).ToArray(), _columns.Select(col => (col.Name, col.InputColumnName)).ToArray()));
        }
Example #13
        public void Save(ModelSaveContext ctx)
        {
            _host.AssertValue(ctx);
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // byte: indicator for frozen models
            // stream: tensorFlow model.
            // int: number of input columns
            // for each input column
            //   int: id of input column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name
            var isFrozen = string.IsNullOrEmpty(_savedModelPath);

            ctx.Writer.WriteBoolByte(isFrozen);
            if (isFrozen)
            {
                var buffer = new TFBuffer();
                Session.Graph.ToGraphDef(buffer);
                ctx.SaveBinaryStream("TFModel", w =>
                {
                    w.WriteByteArray(buffer.ToArray());
                });
            }
            else
            {
                ctx.SaveBinaryStream("TFSavedModel", w =>
                {
                    string[] modelFilePaths = Directory.GetFiles(_savedModelPath, "*", SearchOption.AllDirectories);
                    w.Write(modelFilePaths.Length);

                    foreach (var fullPath in modelFilePaths)
                    {
                        var relativePath = fullPath.Substring(_savedModelPath.Length + 1);
                        w.Write(relativePath);

                        using (var fs = new FileStream(fullPath, FileMode.Open))
                        {
                            long fileLength = fs.Length;
                            w.Write(fileLength);
                            long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
                            _host.Assert(actualWritten == fileLength);
                        }
                    }
                });
            }
            _host.AssertNonEmpty(Inputs);
            ctx.Writer.Write(Inputs.Length);
            foreach (var colName in Inputs)
            {
                ctx.SaveNonEmptyString(colName);
            }

            _host.AssertNonEmpty(Outputs);
            ctx.Writer.Write(Outputs.Length);
            foreach (var colName in Outputs)
            {
                ctx.SaveNonEmptyString(colName);
            }
        }
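
The layout Save writes is symmetric with BinaryReader: Write on a string length-prefixes the relative path, and Write on a long records each file's size before its raw bytes. Below is a hedged sketch of a reader for the SavedModel branch; the helper name and the copy loop are assumptions, not ML.NET's actual load path.

using System;
using System.IO;

static class SavedModelReaderSketch
{
    // Reads back the file list written by the "TFSavedModel" branch above:
    // int count, then per file: string relative path, long length, raw bytes.
    public static void Restore(BinaryReader r, string destDir)
    {
        int fileCount = r.ReadInt32();
        for (int i = 0; i < fileCount; i++)
        {
            string relativePath = r.ReadString();
            long length = r.ReadInt64();
            string fullPath = Path.Combine(destDir, relativePath);
            Directory.CreateDirectory(Path.GetDirectoryName(fullPath));
            using (var fs = File.Create(fullPath))
            {
                var buffer = new byte[81920];
                long remaining = length;
                while (remaining > 0)
                {
                    int read = r.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining));
                    if (read <= 0)
                        throw new EndOfStreamException();
                    fs.Write(buffer, 0, read);
                    remaining -= read;
                }
            }
        }
    }
}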