Example #1
0
        /// <summary>
        /// Evaluates the model over <paramref name="test_input"/> in mini-batches of
        /// <paramref name="batch_size"/>, accumulating the mean criterion loss and the
        /// per-sample predictions. A trailing partial batch (when the sample count is not
        /// divisible by <paramref name="batch_size"/>) is skipped.
        /// </summary>
        /// <param name="test_input">Input tensor; dim 0 is the sample dimension.</param>
        /// <param name="test_target">Target tensor; dim 0 must match <paramref name="test_input"/>.</param>
        /// <param name="criterion">Loss function applied to each batch's predictions.</param>
        /// <param name="batch_size">Number of samples per evaluation batch.</param>
        /// <returns>A comma-separated string: "&lt;loss tensor id&gt;,&lt;predictions tensor id&gt;".</returns>
        /// <exception cref="InvalidDataException">If input and target sample counts differ.</exception>
        public string Evaluate(FloatTensor test_input, FloatTensor test_target, Loss.Loss criterion, int batch_size)
        {
            if (test_input.Shape[0] != test_target.Shape[0])
            {
                throw new InvalidDataException("Input and Target tensors don't seem to have the right dims");
            }

            // Buffer shapes mirror the full tensors except the leading (sample) dim,
            // which is replaced by the batch size.
            int[] input_buffer_shape = (int[])test_input.Shape.Clone();
            input_buffer_shape[0] = batch_size;

            FloatTensor test_input_buffer = controller.floatTensorFactory.Create(_shape: input_buffer_shape, _autograd: true);
            FloatTensor test_loss         = controller.floatTensorFactory.Create(_shape: new int[] { 1 });

            int[] target_buffer_shape = (int[])test_target.Shape.Clone();
            target_buffer_shape[0] = batch_size;

            FloatTensor test_target_buffer = controller.floatTensorFactory.Create(_shape: target_buffer_shape, _autograd: true);
            float       loss        = 0;
            int         num_batches = test_input.Shape[0] / batch_size; // integer division drops any partial batch
            FloatTensor predictions = controller.floatTensorFactory.Create(test_target.Shape);

            // Elements per input batch = batch_size * product of trailing dims.
            int test_input_batch_offset = batch_size;

            for (int i = 1; i < test_input.Shape.Length; i++)
            {
                test_input_batch_offset *= test_input.Shape[i];
            }

            // Elements per target batch, computed the same way.
            int test_target_batch_offset = batch_size;

            for (int i = 1; i < test_target.Shape.Length; i++)
            {
                test_target_batch_offset *= test_target.Shape[i];
            }

            for (int batch_i = 0; batch_i < num_batches; batch_i++)
            {
                // Copy the current batch's slice of input/target into the reusable buffers.
                test_input_buffer.Fill(test_input, starting_offset: batch_i * test_input_batch_offset,
                                       length_to_fill: test_input_batch_offset);
                test_target_buffer.Fill(test_target, starting_offset: batch_i * test_target_batch_offset,
                                        length_to_fill: test_target_batch_offset);
                var pred       = Forward(test_input_buffer);
                var batch_loss = criterion.Forward(pred, test_target_buffer);
                // Write this batch's predictions into their slot in the full predictions tensor.
                predictions.Fill(pred, starting_offset: 0, length_to_fill: test_target_batch_offset, starting_offset_fill: test_target_batch_offset * batch_i);
                loss += (batch_loss.Data[0] / batch_size);
            }
            test_loss.Fill(loss / num_batches);
            return(test_loss.Id.ToString() + "," + predictions.Id.ToString());
        }
Example #2
0
        /// <summary>
        /// Runs one training step on mini-batch <paramref name="batch_i"/>: fills the
        /// input/target buffers from the origin tensors, does a forward pass, backprops
        /// the criterion loss, and applies an optimizer step.
        /// </summary>
        /// <param name="batch_i">Zero-based index of the batch to train on.</param>
        /// <param name="iteration">Current iteration number, forwarded to the optimizer.</param>
        /// <returns>The scalar loss for this batch, or 0 if the batch lies past the end of the data.</returns>
        public float FitBatch(int batch_i, int iteration)
        {
            // Guard: a batch whose slice would read past the origin tensor is skipped entirely.
            if ((batch_i + 1) * _input_batch_offset >= _input_tensor_origin.Size)
            {
                return 0;
            }

            // Copy this batch's slices into the reusable input/target buffers.
            input_buffer.Fill(_input_tensor_origin, starting_offset: batch_i * _input_batch_offset,
                              length_to_fill: _input_batch_offset);
            target_buffer.Fill(_target_tensor_origin, starting_offset: batch_i * _target_batch_offset,
                               length_to_fill: _target_batch_offset);

            var prediction = Forward(input_buffer);
            var batch_loss = _criterion.Forward(prediction, target_buffer);

            // Lazily (re)build the cached all-ones gradient used to seed backprop,
            // only when missing or when the loss tensor's size has changed.
            if (cached_ones_grad_for_backprop == null || cached_ones_grad_for_backprop.Size != batch_loss.Size)
            {
                cached_ones_grad_for_backprop          = batch_loss.createOnesTensorLike();
                cached_ones_grad_for_backprop.Autograd = false;
            }

            batch_loss.Backward(cached_ones_grad_for_backprop);
            _optimizer.Step(this.input_buffer.Shape[0], iteration);

            return batch_loss.Data[0];
        }