/// <summary>
/// Builds the feature package for the query held by <paramref name="prefixEntry"/>.
/// One NTB feature is produced per entry in <c>_ntbs</c>; each feature is the raw
/// (un-normalised) query frequency reported by that NTB. The target likelihood is
/// not set by this method.
/// </summary>
/// <param name="prefixEntry">Index entry whose <c>Query</c> the features are built for.</param>
/// <param name="partialQuery">Key used to index into each NTB before asking it for the query frequency.</param>
/// <param name="includeQueryCountsFeature">
/// Reserved for a query-counts feature. Currently ignored: the feature is not computed.
/// </param>
/// <returns>A new <c>FeaturePackage</c> with <c>Query</c> and <c>NtbFeatures</c> populated.</returns>
private FeaturePackage BuildFeaturePackage(BaseIndexEntry prefixEntry, string partialQuery, bool includeQueryCountsFeature = false)
{
    FeaturePackage tpq = new FeaturePackage();
    tpq.Query = prefixEntry.Query;
    tpq.NtbFeatures = new double[_ntbs.Length];

    // One feature per NTB, in the same order as _ntbs.
    for (int i = 0; i < _ntbs.Length; i++)
    {
        // TODO: the frequency is stored as-is; dividing by the NTB size (ntbN) is deliberately not applied yet.
        tpq.NtbFeatures[i] = _ntbs[i][partialQuery].GetQueryFrequency(prefixEntry.Query);
    }

    // TODO: includeQueryCountsFeature is intentionally a no-op for now. The intended computation was:
    //   tpq.QueriesSinceLastTrain = _currentQueryCount - _queryCountsForPrefix[partialQuery].OldestItem();

    return tpq;
}
/// <summary>
/// Performs one stochastic-gradient-descent update of <c>_modelWeights</c> from a single
/// feature package. The NTB features are divided by <c>_maxNtbSize</c> and the target
/// likelihood by <c>_predictionHorizon</c> before the update. The squared error is added to
/// <c>_totalSqrdError</c>, <c>_trainingInstances</c> is incremented, and a progress line is
/// written to the console every 2000 training instances.
/// </summary>
/// <param name="fp">
/// Feature package supplying <c>NtbFeatures</c> (the inputs) and <c>TargetLikelihood</c>
/// (the value to fit). The package is not modified.
/// </param>
public void TrainModel(FeaturePackage fp)
{
    // Scale into a fresh array. Scaling fp.NtbFeatures in place would silently rescale the
    // caller's package, and training on the same package twice would double-scale it.
    double[] rawFeatures = fp.NtbFeatures;
    double[] instance = new double[rawFeatures.Length];
    for (int i = 0; i < rawFeatures.Length; i++)
    {
        instance[i] = rawFeatures[i] / _maxNtbSize;
    }

    // The target is scaled by the prediction horizon.
    double target = fp.TargetLikelihood / _predictionHorizon;
    double prediction = Predict(instance);

    // The error is the same for every weight, so compute it once.
    double error = prediction - target;
    _totalSqrdError += error * error;

    // SGD step: move each weight against the gradient of the squared error.
    for (int j = 0; j < _modelWeights.Length; j++)
    {
        _modelWeights[j] = _modelWeights[j] - (_learningRateAlpha * error * instance[j]);
    }

    _trainingInstances++;

    // Periodic progress report.
    const int logInterval = 2000;
    if (_trainingInstances % logInterval == 0)
    {
        // Each weight is followed by ", " (including the last) to keep the log format unchanged.
        System.Text.StringBuilder modelStr = new System.Text.StringBuilder();
        foreach (double weight in _modelWeights)
        {
            modelStr.Append(weight.ToString("F8")).Append(", ");
        }

        Console.WriteLine("SGD LR Training instances at " + _trainingInstances.ToString()
            + ", avg sqrd error: " + (_totalSqrdError / Convert.ToDouble(_trainingInstances)).ToString()
            + ", model: " + modelStr.ToString());
    }
}