/// <summary>
/// Selects, at random, one of this state's queries whose feature also occurs in the
/// supplied datavector.
/// </summary>
/// <param name="dataVector">Training datavector; its features restrict which queries are eligible.</param>
/// <param name="rand">Random number generator used to draw the index of the returned query.</param>
/// <returns>
/// The selected query, or <c>null</c> when the state holds no queries at all or none of
/// its queries matches a feature of the datavector.
/// </returns>
/// <remarks>
/// The state is updated through <c>AddMissingQueriesAndLabels</c> before the selection,
/// so calling this method can add queries to the state and touches its labels.
/// </remarks>
public Query GetRandomQuery(DataVectorTraining dataVector, Random rand)
{
    // Register queries/labels from this datavector that the state does not know yet.
    AddMissingQueriesAndLabels(dataVector);

    // Nothing left to ask: every feature is already part of the state.
    if (Queries.Count == 0)
    {
        return null;
    }

    // Collect the queries whose feature is present in the datavector,
    // in the enumeration order of the query table.
    var candidates = new List<Query>();
    foreach (var entry in Queries)
    {
        if (dataVector.Features.Find(f => entry.Key.Feature.Equals(f)) != null)
        {
            candidates.Add(entry.Key);
        }
    }

    // Draw from the random generator only when there is something to pick.
    return candidates.Count == 0
        ? null
        : candidates[rand.Next(candidates.Count)];
}
/// <summary>
/// Trains the network of states with one labelled datavector. The states are later used to
/// decide the label of unseen datavectors.
/// </summary>
/// <param name="dataVector">
/// A training sample holding features, values, relative rewards and the correct classification
/// label. Features whose value is <c>null</c> are removed from this instance's feature list.
/// </param>
/// <returns>
/// The statistics of this learning pass (see the <c>TrainingStats</c> class). A freshly
/// constructed <c>TrainingStats</c> is returned, without any training taking place, when the
/// datavector has no features with a value, no label, or a label without a value.
/// </returns>
public TrainingStats Learn(DataVectorTraining dataVector)
{
    lock (processLock)
    {
        // Strip features that carry no value, then make sure something usable is left.
        dataVector.Features.RemoveAll(feature => feature.Value == null);
        bool usable = dataVector.Features.Count != 0
                      && dataVector.Label != null
                      && dataVector.Label.Value != null;
        if (!usable)
        {
            return new TrainingStats();
        }

        // Drop the current decision tree before the state space changes.
        DecisionTree = null;

        var stats = new TrainingStats();

        // The very first sample has to create the root state.
        if (StateSpace.Count == 0)
        {
            AddState(new State(dataVector), stats);
        }

        // Key 0 is the hashcode of a state without features, i.e. the root state.
        State root = StateSpace[0];
        Learn(root, dataVector, 0, stats);

        stats.StatesTotal = StateSpace.Count;
        return stats;
    }
}
//Constructors

/// <summary>
/// Builds a state that contains every feature of an existing state plus one additional
/// feature. The queries of the new state are derived from the datavector.
/// </summary>
/// <param name="original">The state whose features are copied.</param>
/// <param name="additionalFeature">The feature added on top of the copied ones.</param>
/// <param name="dataVector">The datavector used to build the queries and labels.</param>
/// <remarks>No argument validation is performed by this constructor.</remarks>
public State(State original, FeatureValuePair additionalFeature, DataVectorTraining dataVector)
    : this(original, dataVector)
{
    // The chained constructor already copied the features of "original";
    // only the extra feature is left to add.
    AddFeature(additionalFeature);
}
/// <summary>
/// Builds a state with the same features as an existing state. The queries are first built
/// from the datavector by the chained constructor, then the features are added one by one.
/// </summary>
/// <param name="original">The state whose features are copied.</param>
/// <param name="dataVector">The datavector used to build the queries and labels.</param>
/// <remarks>No argument validation is performed by this constructor.</remarks>
public State(State original, DataVectorTraining dataVector)
    : this(dataVector)
{
    // Go through AddFeature for every feature instead of copying the collection:
    // the feature list has to stay synchronized with the queries list.
    foreach (var copiedFeature in original.Features)
    {
        AddFeature(copiedFeature);
    }
}
/// <summary>
/// Builds a state that holds the given features. The queries are first built from the
/// datavector by the chained constructor, then the features are added one by one.
/// </summary>
/// <param name="features">The features the new state consists of.</param>
/// <param name="dataVector">An example datavector used to create the initial queries.</param>
/// <remarks>No argument validation is performed by this constructor.</remarks>
public State(List<FeatureValuePair> features, DataVectorTraining dataVector)
    : this(dataVector)
{
    // Add the features in list order, each through AddFeature.
    foreach (var feature in features)
    {
        AddFeature(feature);
    }
}
/// <summary>
/// Adjusts, in every state that is exactly one feature short of <paramref name="nextState"/>,
/// the query that leads to <paramref name="nextState"/>. Predecessor states that do not exist
/// yet are created and registered, but their query is not adjusted in this pass.
/// </summary>
/// <param name="nextState">The state that is reached after the query is performed.</param>
/// <param name="dataVector">The relevant datavector for training.</param>
/// <param name="trainingDetails">Statistics of the learning process; updated when states are created.</param>
private void ParallelPathsUpdate(State nextState, DataVectorTraining dataVector, TrainingStats trainingDetails)
{
    lock (processLock)
    {
        // Expected reward of the correct label in the state the queries lead to.
        double labelReward = nextState.Labels[dataVector.Label];

        // Work on a snapshot of the features of "nextState".
        List<FeatureValuePair> features = nextState.Features.ToList();

        foreach (var feature in features)
        {
            // Hashcode of the state that differs from "nextState" only by lacking this feature,
            // i.e. a state from which querying the feature leads to "nextState".
            int predecessorHash = nextState.GetHashCodeWithout(feature);

            if (!StateSpace.ContainsKey(predecessorHash))
            {
                // The predecessor is unknown: build it from all features but the current one.
                var remaining = new List<FeatureValuePair>(features);
                remaining.Remove(feature);

                State created = new State(remaining, dataVector);
                AddState(created, trainingDetails);

                // NOTE(review): AddState also receives trainingDetails and the other call sites
                // do not increment StatesCreated themselves - confirm this is not counted twice.
                trainingDetails.StatesCreated++;

                // A freshly created state keeps its initial query rewards.
                continue;
            }

            State predecessor = StateSpace[predecessorHash];

            // The query in the predecessor that leads to "nextState".
            Query leadingQuery = new Query(feature, dataVector.Label);

            // Reward the datavector assigns to querying this feature.
            double queryReward = dataVector[feature.Name].Importance;

            predecessor.AdjustQuery(leadingQuery, labelReward, queryReward, DiscountFactor);
        }
    }
}
/// <summary>
/// Builds a "root" state: it has no features, only the queries and the label derived
/// from the datavector.
/// </summary>
/// <param name="dataVector">The datavector used to build the queries and labels.</param>
/// <remarks>No argument validation is performed by this constructor.</remarks>
public State(DataVectorTraining dataVector)
{
    // Start from empty collections...
    Features = new HashSet<FeatureValuePair>();
    FeatureNames = new HashSet<string>();
    Queries = new Dictionary<Query, double>();
    Labels = new Dictionary<FeatureValuePair, double>();
    LabelsCount = new Dictionary<FeatureValuePair, int>();

    // ...then let the datavector contribute its queries and label.
    AddMissingQueriesAndLabels(dataVector);
}
/// <summary>
/// Inspects every feature of the datavector. A feature whose name is not yet part of this
/// state yields a query (feature combined with the datavector's label); the query is added
/// when it is not in the list of queries yet. Afterwards the label of the datavector is
/// passed to <c>AdjustLabels</c>.
/// New queries are created optimistically with an expected reward of 1.
/// </summary>
/// <param name="dataVector">The datavector that supplies the features and the label.</param>
public void AddMissingQueriesAndLabels(DataVectorTraining dataVector)
{
    // Iterate over a copy of the feature list.
    var snapshot = dataVector.Features.ToList();

    foreach (FeatureValuePairWithImportance candidate in snapshot)
    {
        // Only feature names that the state does not contain can still be queried.
        if (!FeatureNames.Contains(candidate.Name))
        {
            var query = new Query(candidate, dataVector.Label);

            // Keep the expected reward of queries that are already known.
            if (!Queries.ContainsKey(query))
            {
                Queries.Add(query, 1);
            }
        }
    }

    // Make the label known to the state.
    AdjustLabels(dataVector.Label);
}
/// <summary>
/// Recursive training step for one state. A query is selected for the state (a random one
/// when the drawn number is below <c>ExplorationRate</c>, otherwise the best one), the state
/// reached through that query is trained recursively, and finally the expected reward of the
/// query is adjusted. The recursion ends when no query is available or the query limit is
/// exceeded; the labels of the state are adjusted at that point.
/// </summary>
/// <param name="currentState">The state to be updated and analyzed.</param>
/// <param name="dataVector">The features, rewards and label to update with.</param>
/// <param name="totalQueries">
/// Number of queries already performed on the current path. The public entry point passes 0
/// and each recursive call passes <c>totalQueries + 1</c>.
/// </param>
/// <param name="trainingDetails">Provides statistics of the learning process.</param>
private void Learn(State currentState, DataVectorTraining dataVector, int totalQueries, TrainingStats trainingDetails)
{
    lock (processLock)
    {
        // Explore with a random query, or exploit the best known one.
        bool explore = rand.NextDouble() < ExplorationRate;
        Query chosenQuery = explore
            ? currentState.GetRandomQuery(dataVector, rand)
            : currentState.GetBestQuery(dataVector);

        // The limit is checked after the selection: the selection itself updates the state
        // (GetRandomQuery adds missing queries and labels), so the order matters.
        if (totalQueries > QueriesLimit)
        {
            chosenQuery = null;
        }

        bool endOfPath = chosenQuery == null;

        // Labels are adjusted at the end of a path, or in every visited state
        // when ParallelReportUpdatesEnabled is set.
        if (endOfPath || ParallelReportUpdatesEnabled)
        {
            currentState.AdjustLabels(dataVector.Label);
        }

        // Without a query the training for this datapoint is finished.
        if (endOfPath)
        {
            return;
        }

        // Look up the state reached by the query; create and register it when unknown.
        int successorHash = currentState.GetHashCodeWith(chosenQuery.Feature);
        State successor;
        if (!StateSpace.ContainsKey(successorHash))
        {
            successor = new State(currentState, chosenQuery.Feature, dataVector);
            AddState(successor, trainingDetails);
        }
        else
        {
            successor = StateSpace[successorHash];
        }

        // Train the successor first; its label reward feeds the query adjustment below.
        Learn(successor, dataVector, totalQueries + 1, trainingDetails);
        trainingDetails.QueriesTotal++;

        if (ParallelQueryUpdatesEnabled)
        {
            // Update every query, in any state, that leads to the successor.
            ParallelPathsUpdate(successor, dataVector, trainingDetails);
        }
        else
        {
            // Update only the query of the current state.
            double queryReward = dataVector[chosenQuery.Feature.Name].Importance;
            double successorLabelReward = successor.Labels[dataVector.Label];
            currentState.AdjustQuery(chosenQuery, successorLabelReward, queryReward, DiscountFactor);
        }
    }
}