/// <summary>
/// Computes the loss of the supplied model weights over <paramref name="data"/>: the sum of
/// <c>weight * CalculateLoss(label, response)</c> over all usable rows, divided by the number of usable rows.
/// </summary>
/// <param name="ch">Channel of the caller; not used by this method.</param>
/// <param name="data">Data set to evaluate; its feature, label and (optional) weight columns are read.</param>
/// <param name="norm">Forwarded unchanged to <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="linearWeights">Linear weights of the model being evaluated.</param>
/// <param name="latentWeightsAligned">Latent weights of the model being evaluated.</param>
/// <param name="latentDimAligned">Aligned latent dimension forwarded to <c>CalculateIntermediateVariables</c>.</param>
/// <param name="latentSum">Scratch buffer handed to <c>CalculateIntermediateVariables</c>.</param>
/// <param name="featureFieldBuffer">Scratch buffer filled per row by <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="featureIndexBuffer">Scratch buffer filled per row by <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="featureValueBuffer">Scratch buffer filled per row by <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="buffer">Scratch feature vector handed to <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="badExampleCount">Reset to zero on entry, then incremented once for every skipped row.</param>
/// <returns>
/// The accumulated weighted loss divided by the number of rows that were not skipped.
/// When no row is usable this is <c>0.0 / 0</c>, i.e. <see cref="double.NaN"/>.
/// </returns>
private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool norm, float[] linearWeights, AlignedArray latentWeightsAligned,
    int latentDimAligned, AlignedArray latentSum, int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, VBuffer<float> buffer, ref long badExampleCount)
{
    var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
    var getters = new ValueGetter<VBuffer<float>>[featureColumns.Count];
    float label = 0;
    // Stays at 1 for every row when the data has no weight column (the getter is then null).
    float weight = 1;
    double loss = 0;
    float modelResponse = 0;
    long exampleCount = 0;
    badExampleCount = 0;
    int count = 0;

    // Only the feature, label and (if present) weight columns are requested from the cursor.
    var columns = new List<DataViewSchema.Column>(featureColumns);
    columns.Add(data.Schema.Label.Value);
    if (data.Schema.Weight != null)
    {
        columns.Add(data.Schema.Weight.Value);
    }

    using (var cursor = data.Data.GetRowCursor(columns))
    {
        var labelGetter = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index);
        // Use the same accessor TrainCore uses for the weight column, so that validation reads
        // the weight role exactly the way training does.
        var weightGetter = data.Schema.Weight?.Index is int weightIdx
            ? RowCursorUtils.GetGetterAs<float>(NumberType.R4, cursor, weightIdx)
            : null;
        for (int f = 0; f < featureColumns.Count; f++)
        {
            getters[f] = cursor.GetGetter<VBuffer<float>>(featureColumns[f].Index);
        }

        while (cursor.MoveNext())
        {
            labelGetter(ref label);
            weightGetter?.Invoke(ref weight);

            // x - x is 0 for a finite x and NaN otherwise, so this is finite only
            // when both the label and the weight are finite.
            float annihilation = label - label + weight - weight;
            if (!FloatUtils.IsFinite(annihilation))
            {
                badExampleCount++;
                continue;
            }

            // A false return marks the row as unusable; it is counted and skipped.
            if (!FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(getters, buffer, norm, ref count,
                featureFieldBuffer, featureIndexBuffer, featureValueBuffer))
            {
                badExampleCount++;
                continue;
            }

            FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(featureColumns.Count, latentDimAligned, count,
                featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
            loss += weight * CalculateLoss(label, modelResponse);
            exampleCount++;
        }
    }

    return loss / exampleCount;
}
/// <summary>
/// Scores a single example with this model: the features exposed by <paramref name="getters"/> are loaded
/// into the supplied scratch buffers and the model response is computed from them.
/// </summary>
/// <param name="getters">One feature-vector getter per field, positioned on the row to score.</param>
/// <param name="featureBuffer">Scratch feature vector handed to <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="featureFieldBuffer">Scratch buffer filled by <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="featureIndexBuffer">Scratch buffer filled by <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="featureValueBuffer">Scratch buffer filled by <c>LoadOneExampleIntoBuffer</c>.</param>
/// <param name="latentSum">Scratch buffer handed to <c>CalculateIntermediateVariables</c>.</param>
/// <returns>The response written by <c>CalculateIntermediateVariables</c>.</returns>
internal float CalculateResponse(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer,
    int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, AlignedArray latentSum)
{
    float response = 0;
    int loadedCount = 0;

    // NOTE(review): the boolean result of LoadOneExampleIntoBuffer is discarded here, whereas the
    // training/validation paths treat a false result as a bad example and skip it — confirm this is intended.
    FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(
        getters, featureBuffer, _norm, ref loadedCount,
        featureFieldBuffer, featureIndexBuffer, featureValueBuffer);

    FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(
        FieldCount, LatentDimAligned, loadedCount,
        featureFieldBuffer, featureIndexBuffer, featureValueBuffer,
        _linearWeights, _latentWeightsAligned, latentSum, ref response);

    return response;
}
/// <summary>
/// Trains a field-aware factorization machine on <paramref name="data"/> for <c>_numIterations</c> passes
/// and returns the resulting model parameters.
/// </summary>
/// <param name="ch">Channel used for checks, warnings and parameter exceptions.</param>
/// <param name="pch">Progress channel that receives the header and, when <c>_verbose</c> is set, one checkpoint per pass.</param>
/// <param name="data">Training data; must have a binary label and known-size float vector feature columns.</param>
/// <param name="validData">Optional validation data; when given and <c>_verbose</c> is set, its loss is reported after each pass.</param>
/// <param name="predictor">Optional existing model handed to <c>InitializeTrainingState</c>; its feature count and latent dimension must match.</param>
/// <returns>Model parameters built from the final linear and latent weights.</returns>
/// <remarks>
/// Raises through <c>ch.ExceptParam</c> when a feature column is not a known-size vector of R4, and through
/// <c>ch.Check</c> when the model would be too large or <paramref name="predictor"/> does not match.
/// </remarks>
private FieldAwareFactorizationMachineModelParameters TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data,
    RoleMappedData validData = null, FieldAwareFactorizationMachineModelParameters predictor = null)
{
    Host.AssertValue(ch);
    Host.AssertValue(pch);

    data.CheckBinaryLabel();
    var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
    int fieldCount = featureColumns.Count;
    int totalFeatureCount = 0;
    int[] fieldColumnIndexes = new int[fieldCount];
    for (int f = 0; f < fieldCount; f++)
    {
        var col = featureColumns[f];
        Host.Assert(!col.IsHidden);
        if (!(col.Type is VectorType vectorType) || !vectorType.IsKnownSize || vectorType.ItemType != NumberType.Float)
        {
            throw ch.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type);
        }
        Host.Assert(vectorType.Size > 0);
        fieldColumnIndexes[f] = col.Index;
        totalFeatureCount += vectorType.Size;
    }

    // 'checked' makes an int overflow in the product throw instead of wrapping to a small value.
    ch.Check(checked(totalFeatureCount * fieldCount * _latentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension or the number of fields too large");
    if (predictor != null)
    {
        ch.Check(predictor.FeatureCount == totalFeatureCount, "Input model's feature count mismatches training feature count");
        ch.Check(predictor.LatentDim == _latentDim, "Input model's latent dimension mismatches trainer's");
    }

    if (validData != null)
    {
        validData.CheckBinaryLabel();
        // The validation columns must come from validData's own schema; reading them from the
        // training schema would compare the training columns with themselves.
        var validFeatureColumns = validData.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
        Host.Assert(fieldCount == validFeatureColumns.Count);
        for (int f = 0; f < fieldCount; f++)
        {
            var featCol = featureColumns[f];
            var validFeatCol = validFeatureColumns[f];
            Host.Assert(featCol.Name == validFeatCol.Name);
            // Equals instead of ==: presumably two schemas can carry distinct but equivalent
            // type instances, which a reference comparison would reject — TODO confirm.
            Host.Assert(featCol.Type.Equals(validFeatCol.Type));
        }
    }

    bool shuffle = _shuffle;
    if (shuffle && !data.Data.CanShuffle)
    {
        ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
        shuffle = false;
    }

    // A null rng is passed to GetRowCursor when shuffling is off.
    var rng = shuffle ? Host.Rand : null;
    var featureGetters = new ValueGetter<VBuffer<float>>[fieldCount];
    var featureBuffer = new VBuffer<float>();
    var featureValueBuffer = new float[totalFeatureCount];
    var featureIndexBuffer = new int[totalFeatureCount];
    var featureFieldBuffer = new int[totalFeatureCount];
    var latentSum = new AlignedArray(fieldCount * fieldCount * _latentDimAligned, 16);
    var metricNames = new List<string>() { "Training-loss" };
    if (validData != null)
    {
        metricNames.Add("Validation-loss");
    }

    int iter = 0;
    long exampleCount = 0;
    long badExampleCount = 0;
    long validBadExampleCount = 0;
    double loss = 0;
    double validLoss = 0;

    // The progress callback captures 'iter' and 'exampleCount', so it reports their current values.
    pch.SetHeader(new ProgressHeader(metricNames.ToArray(), new string[] { "iterations", "examples" }), entry =>
    {
        entry.SetProgress(0, iter, _numIterations);
        entry.SetProgress(1, exampleCount);
    });

    // Only the feature, label and (if present) weight columns are requested from the cursor.
    var columns = data.Schema.Schema.Where(x => fieldColumnIndexes.Contains(x.Index)).ToList();
    columns.Add(data.Schema.Label.Value);
    if (data.Schema.Weight != null)
    {
        columns.Add(data.Schema.Weight.Value);
    }

    // NOTE(review): InitializeTrainingState is defined elsewhere; presumably it seeds the weights
    // from 'predictor' when one is supplied — confirm.
    InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights,
        out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned);

    // refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
    while (iter++ < _numIterations)
    {
        using (var cursor = data.Data.GetRowCursor(columns, rng))
        {
            var labelGetter = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index);
            var weightGetter = data.Schema.Weight?.Index is int weightIdx ? RowCursorUtils.GetGetterAs<float>(NumberType.R4, cursor, weightIdx) : null;
            for (int i = 0; i < fieldCount; i++)
            {
                featureGetters[i] = cursor.GetGetter<VBuffer<float>>(fieldColumnIndexes[i]);
            }

            // Per-pass statistics: each pass starts from zero, so after the loop these describe the last pass only.
            loss = 0;
            exampleCount = 0;
            badExampleCount = 0;
            while (cursor.MoveNext())
            {
                float label = 0;
                float weight = 1;
                int count = 0;
                float modelResponse = 0;
                labelGetter(ref label);
                weightGetter?.Invoke(ref weight);

                // x - x is 0 for a finite x and NaN otherwise, so this is finite only
                // when both the label and the weight are finite.
                float annihilation = label - label + weight - weight;
                if (!FloatUtils.IsFinite(annihilation))
                {
                    badExampleCount++;
                    continue;
                }

                // A false return marks the row as unusable; it is counted and skipped.
                if (!FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(featureGetters, featureBuffer, _norm, ref count,
                    featureFieldBuffer, featureIndexBuffer, featureValueBuffer))
                {
                    badExampleCount++;
                    continue;
                }

                // refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count,
                    featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
                var slope = CalculateLossSlope(label, modelResponse);

                // refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count,
                    featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned);

                // The loss is accumulated with the response computed before this row's update.
                loss += weight * CalculateLoss(label, modelResponse);
                exampleCount++;
            }

            // Becomes NaN (0.0 / 0) when every row of the pass was skipped.
            loss /= exampleCount;
        }

        if (_verbose)
        {
            if (validData == null)
            {
                pch.Checkpoint(loss, iter, exampleCount);
            }
            else
            {
                // The training scratch buffers are reused for the validation pass.
                validLoss = CalculateAvgLoss(ch, validData, _norm, linearWeights, latentWeightsAligned, _latentDimAligned, latentSum,
                    featureFieldBuffer, featureIndexBuffer, featureValueBuffer, featureBuffer, ref validBadExampleCount);
                pch.Checkpoint(loss, validLoss, iter, exampleCount);
            }
        }
    }

    if (badExampleCount != 0)
    {
        ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set");
    }
    if (validBadExampleCount != 0)
    {
        ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
    }

    return new FieldAwareFactorizationMachineModelParameters(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
}