示例#1
0
        /// <summary>
        /// Cross-entropy style loss: returns the negated sum, over all elements, of
        /// <c>target * log(prediction)</c>.
        /// </summary>
        /// <param name="prediction">
        /// Passed straight to <c>Log()</c>. No softmax is applied in this method, so these values
        /// are presumably already probabilities (e.g. softmax output), not raw logits —
        /// NOTE(review): confirm with callers.
        /// </param>
        /// <param name="target">Multiplied element-wise with the log of <paramref name="prediction"/>.</param>
        /// <returns>The result of <c>Sum()</c> followed by <c>Neg()</c> on the weighted log-predictions.</returns>
        public override FloatTensor Forward(FloatTensor prediction, FloatTensor target)
        {
            // TODO check shapes

            var logPrediction = prediction.Log();
            var weighted      = target.Mul(logPrediction);

            FloatTensor output = weighted.Sum().Neg();

            return output;
        }
示例#2
0
        /// <summary>
        /// Softmax cross-entropy: applies softmax to <paramref name="prediction"/> and returns
        /// <c>-sum(target * log(softmax(prediction)))</c>.
        /// </summary>
        /// <param name="prediction">Logits (pre-softmax); softmax is applied inside this method.</param>
        /// <param name="target">Multiplied element-wise with the log-probabilities.</param>
        /// <returns>The summed weighted log-probabilities, negated via <c>Mul(-1)</c>.</returns>
        protected override FloatTensor Forward(FloatTensor prediction, FloatTensor target)
        {
            // TODO check shapes

            FloatTensor softmax = Functional.Softmax(prediction);

            // Cross-entropy needs the natural log of the probabilities. Log1p computes log(1 + p),
            // a different function (it yields log 2 instead of 0 for p = 1), so Log() is used.
            // NOTE(review): log(softmax(x)) becomes -Infinity if a probability underflows to 0;
            // a fused log-softmax would be more stable if the library offers one.
            FloatTensor output = ((target.Mul(softmax.Log())).Sum()).Mul(-1);

            return output;
        }
示例#3
0
        public void ElementwiseMultiplicationUnequalShapes()
        {
            // Both operands hold six elements, but one is laid out 2x3 and the other 3x2.
            // Element-wise Mul is expected to throw InvalidOperationException on the shape mismatch.
            var leftValues  = new float[] { 1, 2, 3, 4, 5, 6 };
            var rightValues = new float[] { 1, 2, 3, 4, 5, 6 };

            var left  = new FloatTensor(leftValues, new int[] { 2, 3 });
            var right = new FloatTensor(rightValues, new int[] { 3, 2 });

            Assert.That(() => left.Mul(right), Throws.TypeOf<InvalidOperationException>());
        }
示例#4
0
        /// <summary>
        /// Dropout forward pass: multiplies <paramref name="input"/> by a mask sampled from a cached
        /// tensor filled with the keep probability <c>1 - rate</c>, and records the output's id in
        /// <c>activation</c>.
        /// </summary>
        /// <param name="input">Tensor to apply the sampled mask to.</param>
        /// <returns>The element-wise product of <paramref name="input"/> and the sampled mask.</returns>
        public override FloatTensor Forward(FloatTensor input)
        {
            // The mask source is (re)built lazily: on the first call, or when the element count changes.
            // NOTE(review): only Size is compared, so an input with the same element count but a
            // different shape reuses the old mask source — confirm Mul accepts that.
            bool rebuildMaskSource = _mask_source == null || input.Size != _mask_source.Size;

            if (rebuildMaskSource)
            {
                _mask_source = input.emptyTensorCopy(hook_graph: false);
                _mask_source.Fill(1 - rate, inline: true);
            }

            // A fresh mask is drawn on every call.
            // NOTE(review): kept values are not rescaled by 1 / (1 - rate); confirm that is intended.
            var sampledMask = _mask_source.SampleMask();

            FloatTensor output = input.Mul(sampledMask);

            activation = output.Id;
            return output;
        }
示例#5
0
        public void ScalarMultiplication()
        {
            // Operand values span the float extremes, negatives, zero and fractions.
            float[] values = { float.MinValue, -10, -1.5f, 0, 1.5f, 10, 20, float.MaxValue };
            int[]   shape  = { 2, 4 };

            // Separate tensor used only to compute the expected products.
            var reference = new FloatTensor(values, shape);

            // Multipliers: zero, positive, negative, and a small fraction.
            float[] scalars = { 0, 99, -99, 0.000001f };

            foreach (float scalar in scalars)
            {
                // Each multiplier gets its own freshly constructed operand.
                var operand = new FloatTensor(values, shape);
                var product = operand.Mul(scalar);

                for (int index = 0; index < operand.Size; index++)
                {
                    Assert.AreEqual(reference.Data [index] * scalar, product.Data [index]);
                }
            }
        }
示例#6
0
        /// <summary>
        /// Checks that element-wise <c>Mul</c> of two equal-shaped 2x4 tensors matches the
        /// per-element float product.
        /// </summary>
        public void ElementwiseMultiplication()
        {
            // Both operands hold the same values, including float.MinValue / float.MaxValue,
            // whose squares overflow to +Infinity in single precision.
            float[] data1   = { float.MinValue, -10, -1.5f, 0, 1.5f, 10, 20, float.MaxValue };
            int[]   shape1  = { 2, 4 };
            var     tensor1 = new FloatTensor(data1, shape1);

            float[] data2   = { float.MinValue, -10, -1.5f, 0, 1.5f, 10, 20, float.MaxValue };
            int[]   shape2  = { 2, 4 };
            var     tensor2 = new FloatTensor(data2, shape2);

            var tensorMult = tensor1.Mul(tensor2);

            for (int i = 0; i < tensorMult.Size; i++)
            {
                float expected = tensor1.Data [i] * tensor2.Data [i];

                // NUnit's Assert.AreEqual takes (expected, actual); keeping that order makes
                // failure messages report the values the right way round.
                Assert.AreEqual(expected, tensorMult.Data [i]);
            }
        }