Ejemplo n.º 1
0
    /// <summary>
    /// Builds a network from previously saved biases/weights plus the given hyper-parameters.
    /// </summary>
    /// <param name="data">Saved parameters; <c>m_biases</c> and <c>m_weights</c> are passed to InitializeBiases / InitializeWeights.</param>
    /// <param name="layerLengths">Node count per layer; the array reference is stored, not copied.</param>
    /// <param name="learnRate">Learning rate stored in <c>m_learnRate</c>.</param>
    /// <param name="batchSize">Mini-batch size; values &lt;= 0 are replaced by 1 (with a log message).</param>
    /// <param name="dropoutKeepRate">Stored in <c>m_dropoutKeepRate</c>.</param>
    /// <param name="weightDecayRate">Stored in <c>m_weightDecayRate</c>.</param>
    /// <param name="activisionType">Activation type; also passed to SetActivisionFunction.</param>
    /// <param name="activisionTypeOutput">Activation type stored for the output layer.</param>
    /// <param name="costType">Cost function type.</param>
    /// <param name="initializationType">Initialization type.</param>
    /// <param name="activisionCoeffitient">Coefficient stored in <c>m_activisionCoeffitient</c>.</param>
    /// <param name="currentDropoutSeed">NOTE(review): accepted but not stored by this constructor — confirm whether it should set m_currentDropoutSeed.</param>
    /// <param name="dropoutSeed">NOTE(review): accepted but not stored by this constructor — confirm whether it should set m_initDropoutSeed.</param>
    public NeuralNetwork(
        NNSaveData data,
        int[] layerLengths,
        float learnRate,
        int batchSize,
        float dropoutKeepRate,
        float weightDecayRate,
        ActivisionFunctionType activisionType,
        ActivisionFunctionType activisionTypeOutput,
        CostFunctionType costType,
        InitializationType initializationType,
        float activisionCoeffitient,
        int currentDropoutSeed,
        int dropoutSeed
        )
    {
        m_layerLengths          = layerLengths;
        m_layerCount            = layerLengths.Length;
        m_learnRate             = learnRate;
        m_activisionCoeffitient = activisionCoeffitient;
        m_dropoutKeepRate       = dropoutKeepRate;
        m_weightDecayRate       = weightDecayRate;

        // Fall back to a batch size of 1 for non-positive input. The valid value is assigned only
        // in the else-branch so the fallback is not overwritten afterwards.
        if (batchSize <= 0)
        {
            Debug.Log("Info: batch size was <= 0 (" + batchSize + "). It was set to a default value of 1!");
            m_batchSize = 1;
        }
        else
        {
            m_batchSize = batchSize;
        }

        m_activisionFunctionType       = activisionType;
        m_activisionFunctionTypeOutput = activisionTypeOutput;
        m_costFunctionType             = costType;
        m_initializationType           = initializationType;

        SetActivisionFunction(activisionType);

        // Load the saved parameters before allocating the batch / back-propagation buffers.
        InitializeBiases(data.m_biases, 0);
        InitializeWeights(data.m_weights, 0);

        InitializeBatch();
        InitializeBackPropagation();
    }
Ejemplo n.º 2
0
    /// <summary>
    /// Copies the network's current biases, weights and dropout seeds into a new <c>NNSaveData</c>.
    /// </summary>
    /// <returns>
    /// A snapshot whose per-row float arrays are newly allocated copies (not references to the
    /// live <c>m_data</c> rows). Only column 0 of each bias matrix is stored.
    /// </returns>
    public NNSaveData SaveData()
    {
        // Biases: one float[] per layer, taken from column 0 of every row of the bias matrix.
        JaggedArrayContainer[] finalBiasArray = new JaggedArrayContainer[m_biases.Length];
        for (int layerIndex = 0; layerIndex < finalBiasArray.Length; layerIndex++)
        {
            var biasMatrix = m_biases[layerIndex];
            float[] biasData = new float[biasMatrix.m_rowCountY];
            for (int nodeIndex = 0; nodeIndex < biasData.Length; nodeIndex++)
            {
                biasData[nodeIndex] = biasMatrix.m_data[nodeIndex][0];
            }
            finalBiasArray[layerIndex] = new JaggedArrayContainer(biasData);
        }

        // Weights: per layer, a container holding one row (m_columnCountX floats) per node (m_rowCountY rows).
        JaggedArrayContainer[] finalWeightArray = new JaggedArrayContainer[m_weights.Length];
        for (int layerIndex = 0; layerIndex < finalWeightArray.Length; layerIndex++)
        {
            var weightMatrix = m_weights[layerIndex];
            JaggedArrayContainer[] weightArray = new JaggedArrayContainer[weightMatrix.m_rowCountY];
            for (int nodeIndex = 0; nodeIndex < weightArray.Length; nodeIndex++)
            {
                float[] weightData = new float[weightMatrix.m_columnCountX];
                for (int weightIndex = 0; weightIndex < weightData.Length; weightIndex++)
                {
                    weightData[weightIndex] = weightMatrix.m_data[nodeIndex][weightIndex];
                }
                weightArray[nodeIndex]           = new JaggedArrayContainer();
                weightArray[nodeIndex].dataFloat = weightData;
            }
            finalWeightArray[layerIndex] = new JaggedArrayContainer(weightArray);
        }

        return new NNSaveData
        {
            m_biases             = finalBiasArray,
            m_weights            = finalWeightArray,
            m_initDropoutSeed    = m_initDropoutSeed,
            m_currentDropoutSeed = m_currentDropoutSeed
        };
    }
Ejemplo n.º 3
0
 /// <summary>
 /// Placeholder for restoring state from <paramref name="data"/>.
 /// </summary>
 /// <param name="data">Saved network data; currently ignored.</param>
 /// <remarks>The body is empty, so calling this has no effect. TODO(review): implement or remove.</remarks>
 public void LoadData(NNSaveData data)
 {
     // No-op: nothing from `data` is applied yet.
 }
    /// <summary>
    /// When the "increase size" key is pressed this frame, doubles the screenshot capture
    /// width/height, remaps the first weight layer onto the larger input layer, and reloads
    /// everything through <c>LoadContainer</c>.
    /// </summary>
    /// <remarks>
    /// Each old input weight is copied onto a group of new inputs and scaled so the copies sum to
    /// the original weight:
    /// <list type="bullet">
    /// <item>enemy inputs (old index &lt; old enemy length) use a 2x2 group, each copy x0.25;</item>
    /// <item>player inputs use the same 2x2 group when the new player length is less than twice the
    /// old one, otherwise a 1x2 group, each copy x0.5.</item>
    /// </list>
    /// </remarks>
    private void ManageSize()
    {
        if (!Input.GetKeyDown(m_keyCodeIncreaseSize))
        {
            return;
        }

        NNSaveData data = m_network.SaveData();
        var screenshot = m_sampleManager.GetScreenshotScript();

        // Input-layout sizes before the resize.
        int oldEnemyWidth        = screenshot.GetInputLayerLengthEnemy(0, 0);
        int oldPlayerHeightPixel = screenshot.GetInputLayerLengthPlayer(0, 0);

        screenshot.SetCaptureHeight(screenshot.GetCaptureHeight() * 2);
        screenshot.SetCaptureWidth(screenshot.GetCaptureWidth() * 2);
        screenshot.SetCaptureSize();
        screenshot.SetCaptureSizesPlayer(screenshot.GetCaptureWidth());
        m_layerLengths[0] = screenshot.GetInputLayerLengthTotal(0, 0);

        int newPlayerHeightPixel = screenshot.GetInputLayerLengthPlayer(0, 0);
        int width                = screenshot.GetCaptureWidth();
        int oldWidth             = width / 2;

        // Loop-invariant: whether player weights use the 2x2 spread (true) or the 1x2 split (false).
        bool spreadPlayerOver2x2 = 2 * oldPlayerHeightPixel > newPlayerHeightPixel;

        int nodeCount = data.m_biases[0].dataFloat.Length;
        JaggedArrayContainer[] newWeights = new JaggedArrayContainer[nodeCount];

        for (int nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++)
        {
            float[] oldRow = data.m_weights[0].array[nodeIndex].dataFloat;
            JaggedArrayContainer newRow = new JaggedArrayContainer(m_layerLengths[0], 0);
            bool skipRowBeforeHalfBlocks = true;
            int index = 0;

            for (int weightIndex = 0; weightIndex < oldRow.Length; weightIndex++)
            {
                float weight = oldRow[weightIndex];

                if (weightIndex < oldEnemyWidth || spreadPlayerOver2x2)
                {
                    // Move to the next 2x2 group. At the end of an old row, additionally skip the
                    // new row already covered by the bottom half of the previous groups.
                    if (weightIndex != 0)
                    {
                        index += (weightIndex % oldWidth == 0) ? width + 2 : 2;
                    }
                    SpreadWeightOver2x2(newRow, index, width, weight);
                }
                else
                {
                    // The first 1x2 write is preceded by the same row skip as an end-of-row 2x2 step.
                    if (skipRowBeforeHalfBlocks)
                    {
                        index += width + 2;
                        skipRowBeforeHalfBlocks = false;
                    }
                    newRow.dataFloat[index]     = weight * 0.5f;
                    newRow.dataFloat[index + 1] = weight * 0.5f;
                    index += 2;
                }
            }

            // Assigned after the row is complete so every node gets an entry, including rows
            // that contain no player weights.
            newWeights[nodeIndex] = newRow;
        }

        data.m_weights[0].array = newWeights;

        NNCSaveData containerData = new NNCSaveData
        {
            m_trainingData             = m_trainingManager.SaveData(),
            m_sampleData               = m_sampleManager.SaveData(),
            m_visuilizationNetworkData = m_visualizationNetwork.SaveData(),
            m_networkData              = data,

            m_layerLengths         = m_layerLengths,
            m_dataFileName         = m_dataFileName,
            m_activisionType       = m_activisionType,
            m_activisionTypeOutput = m_activisionTypeOutput,
            m_costType             = m_costType,
            m_initializationType   = m_initializationType,
            m_activisionConstant   = m_activisionConstant
        };

        LoadContainer(containerData);
    }

    /// <summary>
    /// Writes a quarter of <paramref name="weight"/> to the four inputs
    /// <c>index</c>, <c>index + 1</c>, <c>index + width</c> and <c>index + width + 1</c> of <paramref name="row"/>.
    /// </summary>
    private static void SpreadWeightOver2x2(JaggedArrayContainer row, int index, int width, float weight)
    {
        float quarter = weight * 0.25f;
        row.dataFloat[index]             = quarter;
        row.dataFloat[index + 1]         = quarter;
        row.dataFloat[index + width]     = quarter;
        row.dataFloat[index + width + 1] = quarter;
    }