示例#1
0
        /// <summary>
        /// Trains a new Keras classification model on data generated from <paramref name="points"/>
        /// and stores it in <c>_model</c>.
        /// </summary>
        /// <param name="points">The points passed to the data generator to produce the X/Y training data.</param>
        /// <returns>
        /// A task that completes when training has finished. The training work itself is marshalled
        /// through <c>_modelDispatcher.Invoke</c> (presumably so all TensorFlow calls run on one
        /// dedicated thread — TODO confirm).
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// Surfaced through the returned task when no training data is generated, or when the X/Y
        /// value counts do not split evenly into samples.
        /// </exception>
        public Task TrainAsync(MLPointCollection points)
        {
            // Task.Run always targets the default (thread-pool) scheduler; Task.Factory.StartNew would
            // use TaskScheduler.Current, which can be a UI/custom scheduler depending on the caller.
            return Task.Run(() =>
            {
                _modelDispatcher.Invoke(() =>
                {
                    Log.Info("Training");

                    const int numberOfClasses = 3;
                    const int batchSize = 5;
                    const int epochs = 100;
                    const int verbose = 1;
                    const float validationSplit = 0.1F;

                    var xydata = _dataGenerator.GetPointsXYData(points);
                    var itemsPerX = xydata.dataItemsPerX;
                    var xValues = xydata.x.ToArray();
                    var yValues = xydata.y.ToArray();

                    // Validate up front so bad data fails with a clear message instead of an opaque
                    // reshape/TensorFlow shape-mismatch error (or a divide-by-zero) further down.
                    if (itemsPerX <= 0 || xValues.Length == 0)
                    {
                        throw new InvalidOperationException("No training data was generated for the supplied points.");
                    }

                    if (xValues.Length % itemsPerX != 0)
                    {
                        throw new InvalidOperationException(
                            $"Generated X data has {xValues.Length} values, which is not a multiple of {itemsPerX} values per sample.");
                    }

                    var count = xValues.Length / itemsPerX;

                    if (yValues.Length != count * numberOfClasses)
                    {
                        throw new InvalidOperationException(
                            $"Generated Y data has {yValues.Length} values; expected {count * numberOfClasses} ({count} samples x {numberOfClasses} classes).");
                    }

                    // One row per sample: X is (count, itemsPerX), Y is (count, numberOfClasses).
                    var x = np.array(xValues).reshape(count, itemsPerX);
                    var y = np.array(yValues).reshape(count, numberOfClasses);

                    _model = keras.Sequential(
                        new List<ILayer>
                        {
                            // Declares the input width so the first Dense layer is built for itemsPerX values.
                            new Flatten(new FlattenArgs
                            {
                                InputShape = new TensorShape(itemsPerX)
                            }),
                            keras.layers.Dense(itemsPerX, activation: "relu"),
                            keras.layers.Dropout(0.2F),
                            keras.layers.Dense(12, activation: "relu"),
                            keras.layers.Dropout(0.2F),
                            keras.layers.Dense(6, activation: "relu"),
                            keras.layers.Dense(numberOfClasses, activation: "softmax"),
                        });

                    // Y holds one value per class for every sample (presumably one-hot — TODO confirm),
                    // which is the label layout CategoricalCrossentropy expects; the sparse variant
                    // expects a single integer label per sample instead.
                    _model.compile(
                        keras.optimizers.Adam(0.01F),
                        keras.losses.CategoricalCrossentropy(),
                        new[] { "acc" });

                    _model.fit(x, y, batchSize, epochs, verbose, validation_split: validationSplit);
                    Log.Info("Training complete");
                });
            });
        }
示例#2
0
        /// <summary>
        /// Returns the X/Y data for a single point by building a candle lookup from
        /// <paramref name="points"/> and delegating to the lookup-based overload.
        /// </summary>
        /// <param name="p">The point whose data is requested.</param>
        /// <param name="points">The collection the candle lookup is built from.</param>
        /// <returns>
        /// The <c>x</c> values, <c>y</c> values, and the candles used, exactly as returned by the
        /// lookup-based overload.
        /// </returns>
        public (List<float> x, List<float> y, List<Candle> candlesUsed) GetPointXYData(MLPoint p, MLPointCollection points) =>
            GetPointXYData(p, GetCandlesLookup(points));