private void LoadWeightsFromFile()
    {
        /// Loads the network weights from the first line of the file at <c>weightPath</c>.
        /// When <c>weightPath</c> is empty it defaults to Application.dataPath + "/CarRacing/weights.txt".
        /// Does nothing if the file does not exist or is empty.
        if (string.IsNullOrEmpty(weightPath))
        {
            weightPath = Application.dataPath + "/CarRacing/weights.txt";
        }

        // Check existence BEFORE opening: File.OpenText throws FileNotFoundException for a
        // missing file, which made the original Exists check unreachable.
        if (!File.Exists(weightPath))
        {
            return;
        }

        // 'using' guarantees the file handle is released, even if LoadWeights throws.
        using (StreamReader streamReader = File.OpenText(weightPath))
        {
            string line = streamReader.ReadLine();
            // ReadLine returns null for an empty file; don't hand null to the network.
            if (line != null)
            {
                ann.LoadWeights(line);
            }
        }
    }
示例#2
0
文件: ANNDrive.cs 项目: stasoz/AI
    /// <summary>
    /// Coroutine that trains <c>ann</c> on Application.dataPath + "/trainingData.txt" for
    /// <c>epochs</c> passes, yielding once per epoch.
    /// </summary>
    /// <remarks>
    /// Each row is comma-separated: columns 0-4 are inputs, columns 5-6 are targets. Rows where
    /// either target is zero contribute no error, but they still count toward the per-epoch average.
    /// After every epoch the averaged error is compared with <c>lastSSE</c>. If it got worse, the
    /// weights from the start of the epoch are restored and alpha drops by 0.001. Otherwise alpha
    /// rises by 0.001 and <c>lastSSE</c> is updated. Alpha is clamped to [0.01, 0.9].
    /// <c>trainingDone</c> is set at the end, also when the file does not exist.
    /// NOTE(review): Convert.ToDouble/ToSingle parse with the current culture; this must match
    /// whatever code wrote the file.
    /// </remarks>
    IEnumerator LoadTrainingSet()
    {
        string path = Application.dataPath + "/trainingData.txt";

        if (File.Exists(path))
        {
            // Read the file once and reuse it every epoch. Previously a StreamReader was left
            // open (never closed) and rewound via BaseStream without DiscardBufferedData.
            string[] lines     = File.ReadAllLines(path);
            int      lineCount = lines.Length;
            List <double> inputs  = new List <double>();
            List <double> outputs = new List <double>();

            for (int i = 0; i < epochs; i++)
            {
                sse = 0;
                // Snapshot of the weights so a worse epoch can be rolled back.
                string currentWeights = ann.PrintWeights();

                foreach (string line in lines)
                {
                    string[] data = line.Split(',');
                    // Skip blank or truncated rows instead of throwing IndexOutOfRangeException.
                    if (data.Length < 7)
                    {
                        continue;
                    }

                    float thisError = 0;
                    if (System.Convert.ToDouble(data[5]) != 0 && System.Convert.ToDouble(data[6]) != 0)
                    {
                        inputs.Clear();
                        outputs.Clear();
                        for (int col = 0; col < 5; col++)
                        {
                            inputs.Add(System.Convert.ToDouble(data[col]));
                        }

                        double o1 = Map(0, 1, -1, 1, System.Convert.ToSingle(data[5]));
                        outputs.Add(o1);
                        double o2 = Map(0, 1, -1, 1, System.Convert.ToSingle(data[6]));
                        outputs.Add(o2);

                        List <double> calcOutputs = ann.Train(inputs, outputs);
                        thisError = ((Mathf.Pow((float)(outputs[0] - calcOutputs[0]), 2) +
                                      Mathf.Pow((float)(outputs[1] - calcOutputs[1]), 2))) / 2.0f;
                    }
                    sse += thisError;
                }

                trainingProgress = (float)i / (float)epochs;

                // An empty file would give 0/0 = NaN, and NaN would then be stored in lastSSE.
                if (lineCount > 0)
                {
                    sse /= lineCount;
                }

                if (lastSSE < sse)
                {
                    ann.LoadWeights(currentWeights);
                    ann.alpha = Mathf.Clamp((float)ann.alpha - 0.001f, 0.01f, 0.9f);
                }
                else
                {
                    ann.alpha = Mathf.Clamp((float)ann.alpha + 0.001f, 0.01f, 0.9f);
                    lastSSE   = sse;
                }
                yield return null;
            }
        }
        trainingDone = true;
    }
    /// <summary>
    /// Nudges the network's learning rate based on whether the last pass reduced the error.
    /// If <c>sumSquaredError</c> exceeds <c>lastSumSquaredError</c>, the given weights are restored
    /// and alpha is lowered by 0.001. Otherwise alpha is raised by 0.001 and
    /// <c>lastSumSquaredError</c> is updated. Alpha is always clamped to [0.01, 0.9].
    /// </summary>
    /// <param name="currentWeights"> Serialized weights to restore when the error got worse. </param>
    private void AdaptLearning(string currentWeights)
    {
        bool errorGotWorse = lastSumSquaredError < sumSquaredError;

        if (errorGotWorse)
        {
            // Discard this pass: roll the network back to its previous weights.
            ann.LoadWeights(currentWeights);
        }

        // Shrink the step after a regression, grow it after an improvement.
        float alphaStep = errorGotWorse ? -0.001f : 0.001f;
        ann.alpha = Mathf.Clamp((float)ann.alpha + alphaStep, 0.01f, 0.9f);

        if (!errorGotWorse)
        {
            // Remember the new best error as the baseline for the next comparison.
            lastSumSquaredError = sumSquaredError;
        }
    }
示例#4
0
 /// <summary>
 /// Called by Unity once per frame. Space logs the network's current weights;
 /// Return reloads them via ann.LoadWeights().
 /// </summary>
 void Update()
 {
     // GetKeyDown is true only on the frame the key goes down. The previous GetKey fired on
     // every frame the key was held, spamming the log and reloading the weights repeatedly.
     if (Input.GetKeyDown(KeyCode.Space))
     {
         Debug.Log(ann.PrintWeights());
     }
     if (Input.GetKeyDown(KeyCode.Return))
     {
         ann.LoadWeights();
     }
 }
示例#5
0
    /// <summary>
    /// Loads weights into <paramref name="ann"/> from the first line of
    /// Application.dataPath + "/weights.txt". Does nothing if the file is missing or empty.
    /// </summary>
    /// <param name="ann"> Network that receives the weights. </param>
    void LoadWeightsFromFile(ANN ann)
    {
        string path = Application.dataPath + "/weights.txt";

        // Check before opening: File.OpenText throws for a missing file, so the original
        // Exists check (done after opening) could never skip anything.
        if (!File.Exists(path))
        {
            return;
        }

        // 'using' closes the reader and releases the file handle on every path.
        using (StreamReader wf = File.OpenText(path))
        {
            string line = wf.ReadLine();
            // ReadLine returns null for an empty file.
            if (line != null)
            {
                ann.LoadWeights(line);
            }
        }
    }
示例#6
0
    /// <summary>
    /// Creates the network and, when saved weights exist for the current circuit, loads them.
    /// Random ("aleatory") circuits look their weights up through <c>manager</c>; fixed
    /// circuits use <c>managerCircuits</c>. It also records the ball's starting pose and applies
    /// <c>timeScaleValue</c>.
    /// </summary>
    public void Start()
    {
        ann = new ANN(5, 2, 1, 10, 0.5f);

        if (_isAleatoryCircuit)
        {
            if (manager.GetDictionaryCreate() && manager.GetNameDictionary(currentAleatoryCircuitName))
            {
                Debug.Log(currentAleatoryCircuitName);
                // Look the weights up once and reuse them for logging and loading.
                var savedWeights = manager.GetDataDictionary(currentAleatoryCircuitName);
                Debug.Log(savedWeights);
                ann.LoadWeights(savedWeights);
            }
        }
        else
        {
            if (managerCircuits.GetDictionaryCreate() && managerCircuits.GetNameDictionary(circuitName))
            {
                var savedWeights = managerCircuits.GetDataDictionary(circuitName);
                Debug.Log(savedWeights);
                ann.LoadWeights(savedWeights);
            }
        }

        ballStartPos   = ball.transform.position;
        ballStartRot   = ball.transform.rotation;
        Time.timeScale = timeScaleValue;
    }
示例#7
0
    /// <summary>
    /// Loads weights and biases from two '!'-separated text files and applies them to the
    /// bird's network.
    /// </summary>
    /// <remarks>
    /// Commas are converted to dots and values are parsed with the invariant culture, so files
    /// written with either decimal separator are accepted. Empty or whitespace-only entries
    /// (e.g. after a trailing '!') are skipped.
    /// NOTE(review): the paths are hard-coded to one developer's desktop. Consider making them
    /// configurable, e.g. relative to Application.persistentDataPath.
    /// </remarks>
    public void LoadWeights()
    {
        ann = bird.GetANN();

        const string weightsPath = @"C:\Users\Ghost\Desktop\weights.txt";
        const string biasPath    = @"C:\Users\Ghost\Desktop\bias.txt";

        // File.ReadAllText opens, reads and closes the file, so no handle leaks if anything throws.
        List <double> weightsList = ParseBangSeparatedDoubles(System.IO.File.ReadAllText(weightsPath));
        List <double> biasList    = ParseBangSeparatedDoubles(System.IO.File.ReadAllText(biasPath));

        ann.LoadWeights(weightsList, biasList);
    }

    /// <summary>
    /// Splits <paramref name="text"/> on '!' and parses each non-blank token as a double.
    /// Commas are first converted to dots, and tokens are parsed with the invariant culture.
    /// </summary>
    private static List <double> ParseBangSeparatedDoubles(string text)
    {
        List <double> values = new List <double>();
        foreach (string token in text.Replace(",", ".").Split('!'))
        {
            // Whitespace-only tokens (such as a trailing newline) would make Convert.ToDouble throw.
            if (string.IsNullOrWhiteSpace(token))
            {
                continue;
            }
            values.Add(System.Convert.ToDouble(token, System.Globalization.CultureInfo.InvariantCulture));
        }
        return values;
    }
示例#8
0
    /// <summary>
    /// Coroutine that trains the ANN on data recorded from the player
    /// (Application.dataPath + "/trainingData.txt"). It runs <c>epochs</c> passes, yields once per
    /// epoch so OnGUI can refresh the progress, and finally saves the weights.
    /// </summary>
    /// <remarks>
    /// Row format: columns 0-4 are inputs, columns 5-6 are the translation and rotation labels.
    /// NOTE(review): Convert.ToDouble/ToSingle parse with the current culture; this must match
    /// the code that wrote the file.
    /// </remarks>
    IEnumerator LoadTrainingSet()
    {
        string path = Application.dataPath + "/trainingData.txt";

        if (File.Exists(path))
        {
            // Read the data set once and reuse it every epoch. This replaces a StreamReader that
            // was never closed and was rewound without DiscardBufferedData.
            string[] lines     = File.ReadAllLines(path);
            int      lineCount = lines.Length;

            List <double> inputs  = new List <double>();
            List <double> outputs = new List <double>();

            // Loop through the epochs.
            for (int i = 0; i < epochs; i++)
            {
                sse = 0;

                // Current weights (as a string from the ANN), kept so a worse epoch can be rolled back.
                string currentWeights = ann.PrintWeights();

                // Train on the data set, row by row.
                foreach (string line in lines)
                {
                    string[] data = line.Split(',');
                    // Skip blank or truncated rows instead of throwing IndexOutOfRangeException.
                    if (data.Length < 7)
                    {
                        continue;
                    }

                    // Rows with nothing to learn contribute zero error.
                    float thisError = 0;

                    // Leave out rows where either label (translation or rotation) is zero, to shrink the data set.
                    if (System.Convert.ToDouble(data[5]) != 0 && System.Convert.ToDouble(data[6]) != 0)
                    {
                        // Clear out lists from the previous row.
                        inputs.Clear();
                        outputs.Clear();

                        // Columns 0-4 are the network inputs.
                        for (int col = 0; col < 5; col++)
                        {
                            inputs.Add(System.Convert.ToDouble(data[col]));
                        }

                        // Rescale the labels for training; presumably maps [-1, 1] onto [0, 1]
                        // (confirm against Map's parameter order).
                        double o1 = Map(0, 1, -1, 1, System.Convert.ToSingle(data[5]));
                        outputs.Add(o1);
                        double o2 = Map(0, 1, -1, 1, System.Convert.ToSingle(data[6]));
                        outputs.Add(o2);

                        // Calculated output (y-hat).
                        List <double> calcOutputs = ann.Train(inputs, outputs);
                        // Squared error averaged over both labels.
                        thisError = ((Mathf.Pow((float)(outputs[0] - calcOutputs[0]), 2) +
                                      Mathf.Pow((float)(outputs[1] - calcOutputs[1]), 2))) / 2.0f;
                    }
                    // Accumulate this row's error into the epoch total.
                    sse += thisError;
                }

                // Training progress shown on screen.
                trainingProgress = (float)i / (float)epochs;

                // Average the error over all rows. Guard against 0/0 = NaN on an empty file.
                if (lineCount > 0)
                {
                    sse /= lineCount;
                }

                // If the error got worse, reload the previous weights and decrease alpha.
                // Otherwise increase alpha and record the new best error.
                if (lastSSE < sse)
                {
                    ann.LoadWeights(currentWeights);
                    ann.alpha = Mathf.Clamp((float)ann.alpha - 0.001f, 0.01f, 0.9f);
                }
                else
                {
                    ann.alpha = Mathf.Clamp((float)ann.alpha + 0.001f, 0.01f, 0.9f);
                    lastSSE   = sse;
                }

                yield return null; // Allow OnGUI some time to update on-screen values.
            }
        }
        // Training done, save weights.
        trainingDone = true;
        SaveWeightsToFile();
    }
示例#9
0
    /// <summary>
    /// Coroutine that trains <c>ann</c> on the file returned by ann.GetPath("trainingData").
    /// It runs <c>epochs</c> passes, yields once per epoch and then saves the weights via the ANN.
    /// </summary>
    /// <remarks>
    /// Columns 0-4 are inputs and columns 5-6 are targets. Rows with a zero target add no error.
    /// <c>trainingProgress</c> is reported as a percentage (0-100).
    /// NOTE(review): Convert.ToDouble/ToSingle parse with the current culture; this must match the writer.
    /// </remarks>
    IEnumerator LoadTrainingSet()
    {
        string path = ann.GetPath("trainingData");

        if (File.Exists(path))
        {
            // Read once. The previous StreamReader was only closed on the happy path, and it was
            // rewound without DiscardBufferedData.
            string[] lines     = File.ReadAllLines(path);
            int      lineCount = lines.Length;

            List <double> inputs  = new List <double>();
            List <double> outputs = new List <double>();

            for (int i = 0; i < epochs; i++)
            {
                sse = 0;
                string currentWeights = ann.PrintWeights();
                foreach (string line in lines)
                {
                    string[] data = line.Split(',');
                    // Skip blank or truncated rows instead of throwing IndexOutOfRangeException.
                    if (data.Length < 7)
                    {
                        continue;
                    }

                    // If nothing is to be learned, ignore this row.
                    float thisError = 0;
                    if (System.Convert.ToDouble(data[5]) != 0 && System.Convert.ToDouble(data[6]) != 0)
                    {
                        inputs.Clear();
                        outputs.Clear();
                        for (int col = 0; col < 5; col++)
                        {
                            inputs.Add(System.Convert.ToDouble(data[col]));
                        }

                        double out1 = Utils.Map(0, 1, -1, 1, System.Convert.ToSingle(data[5]));
                        double out2 = Utils.Map(0, 1, -1, 1, System.Convert.ToSingle(data[6]));
                        outputs.Add(out1);
                        outputs.Add(out2);

                        List <double> calcOutputs = ann.Train(inputs, outputs);
                        thisError = ((Mathf.Pow((float)(outputs[0] - calcOutputs[0]), 2) +
                                      Mathf.Pow((float)(outputs[1] - calcOutputs[1]), 2))) / 2.0f;
                    }
                    sse += thisError;
                }
                trainingProgress = ((float)i / (float)epochs) * 100;

                // Guard against 0/0 = NaN for an empty file.
                if (lineCount > 0)
                {
                    sse /= lineCount;
                }

                if (lastSSE < sse)   // If sse isn't better, reload the old weights and decrease alpha.
                {
                    ann.LoadWeights(currentWeights);
                    ann.alpha = Mathf.Clamp((float)ann.alpha - 0.001f, 0.01f, 0.9f);
                }
                else     // Increase alpha.
                {
                    ann.alpha = Mathf.Clamp((float)ann.alpha + 0.001f, 0.01f, 0.9f);
                    lastSSE   = sse;
                }
                yield return null;
            }
        }
        trainingDone = true;
        ann.SaveWeightsToFile();
    }
    /// <summary>
    /// Coroutine that trains <c>ann</c> on Application.dataPath + "/trainingData.txt" for
    /// <c>epochs</c> passes, yields once per epoch and then saves the weights.
    /// </summary>
    /// <remarks>
    /// Columns 0-4 are inputs and columns 5-6 are targets. Rows where a target is zero (no user
    /// input) add no error. A worse epoch is rolled back and alpha is lowered. A better or equal
    /// epoch raises alpha and becomes the new <c>lastSSE</c> baseline. Previously the baseline
    /// was never updated, so every epoch was compared against the initial value.
    /// NOTE(review): Convert.ToDouble/ToSingle parse with the current culture; this must match the writer.
    /// </remarks>
    IEnumerator LoadTrainingSet()
    {
        string path = Application.dataPath + "/trainingData.txt";

        if (File.Exists(path))
        {
            // Read once and reuse. This replaces a StreamReader that was never closed and was
            // rewound without DiscardBufferedData.
            string[] lines     = File.ReadAllLines(path);
            int      lineCount = lines.Length;
            List <double> inputs  = new List <double>();
            List <double> outputs = new List <double>();

            for (int i = 0; i < epochs; i++)
            {
                sse = 0;
                string currentWeights = ann.PrintWeights();
                foreach (string line in lines)
                {
                    string[] data = line.Split(',');
                    // Skip blank or truncated rows instead of throwing IndexOutOfRangeException.
                    if (data.Length < 7)
                    {
                        continue;
                    }

                    // If nothing is to be learned, ignore this row.
                    float thisError = 0;
                    if (Convert.ToDouble(data[5]) != 0 && Convert.ToDouble(data[6]) != 0)   // The row had user input (active keys pressed).
                    {
                        inputs.Clear();
                        outputs.Clear();
                        for (int col = 0; col < 5; col++)
                        {
                            inputs.Add(Convert.ToDouble(data[col]));
                        }

                        double o1 = Map(0, 1, -1, 1, Convert.ToSingle(data[5]));
                        outputs.Add(o1);
                        double o2 = Map(0, 1, -1, 1, Convert.ToSingle(data[6]));
                        outputs.Add(o2);

                        List <double> calcOutputs = ann.Train(inputs, outputs);
                        thisError = ((Mathf.Pow((float)(outputs[0] - calcOutputs[0]), 2) + Mathf.Pow((float)(outputs[1] - calcOutputs[1]), 2))) / 2.0f;
                    }
                    sse += thisError;
                }
                trainingProgress = (float)i / (float)epochs;

                // Guard against 0/0 = NaN for an empty file.
                if (lineCount > 0)
                {
                    sse /= lineCount;
                }

                // If sse isn't better, reload the previous set of weights and decrease alpha.
                if (lastSSE < sse)
                {
                    ann.LoadWeights(currentWeights);
                    ann.alpha = Mathf.Clamp((float)ann.alpha - 0.001f, 0.01f, 0.9f);
                }
                else
                {
                    ann.alpha = Mathf.Clamp((float)ann.alpha + 0.001f, 0.01f, 0.9f);
                    // Record the improvement so the next epoch is compared against the best so far.
                    lastSSE   = sse;
                }

                yield return null;
            }
        }
        trainingDone = true;
        SaveWeightsToFile();
    }