コード例 #1
0
            /// <summary>
            /// Builds an element-wise tensor that applies <paramref name="abstraction"/> to <paramref name="inputs"/>.
            /// <paramref name="vars"/> must contain exactly one variable per input (paired by position).
            /// </summary>
            /// <param name="inputs">Tensors the abstraction is applied to.</param>
            /// <param name="vars">Scalar variables standing for the inputs inside <paramref name="abstraction"/>.</param>
            /// <param name="abstraction">Scalar expression over <paramref name="vars"/>.</param>
            /// <returns>
            /// A <c>Deindexing</c> wrapping the element-wise result when every input is a <c>Deindexing</c>
            /// over the same indices and shape; otherwise a new <c>Elementwise</c>.
            /// </returns>
            /// <exception cref="ArgumentException">The lengths of <paramref name="inputs"/> and <paramref name="vars"/> differ.</exception>
            public static Tensor<Type> Create(Tensor<Type>[] inputs, Scalar<Type>.Var[] vars, Scalar<Type> abstraction)
            {
                if (inputs.Length != vars.Length)
                {
                    throw new ArgumentException("Need one captured by inputs");
                }
                var shape = GetShape(inputs);

                // guw: the following code aims at simplifying abstractions like (_x, _y) => _x.
                // For now I haven't seen a case where this happens, but it might in the future.
#if NOT_USED
                var varsInAbstraction = abstraction.FindAll <Scalar <Type> .Var>();
                // easy case, everybody is to be removed
                if (varsInAbstraction.Count == 0)
                {
                    return(Op.Const(abstraction, shape));
                }

                var varsToRemove = vars.Where(v => !varsInAbstraction.Contains(v)).ToList();

                if (varsToRemove.Count > 0)
                {
                    inputs = inputs.Where((_, i) => !varsToRemove.Contains(vars[i])).ToArray();
                    vars   = vars.Where(v => !varsToRemove.Contains(v)).ToArray();
                }
                if (inputs.Length == 0)
                {
                    return(Op.Const(abstraction, shape));
                }
#endif

                // `Deindexing` operations are heavy, so when every input is a Deindexing we try to apply
                // the lambda to the deindexed contents first and deindex only once afterwards.
                // The Length guard matters: All() is vacuously true on an empty array, and inputs[0]
                // would then throw instead of falling through to the plain Elementwise below.
                if (inputs.Length > 0 && inputs.All(x => x is Deindexing<Type>))
                {
                    var indices        = ((Deindexing<Type>)inputs[0]).Indices;
                    var deindexedShape = inputs[0].Shape;
                    // The rewrite is only valid if all inputs share the same indices and shape.
                    if (inputs.All(x => x is Deindexing<Type> d && d.Indices == indices && x.Shape.WillEqualTo(deindexedShape)))
                    {
                        // TODO not covered
                        var contents = inputs.Select(x => ((Deindexing<Type>)x).Content).ToArray();
                        var nary     = Create(contents, vars, abstraction);
                        return Deindexing<Type>.Create(nary, deindexedShape, indices);
                    }
                }

                return new Elementwise(inputs, vars, abstraction, shape);
            }
コード例 #2
0
ファイル: Tensor.cs プロジェクト: stuarthillary/TheaNet
        /// <summary>
        /// Unary negation. A few algebraic simplifications are tried first (double negation, negated
        /// subtraction, constant fills, one-hot tensors); otherwise the negation is applied element-wise.
        /// </summary>
        /// <param name="x">Tensor to negate.</param>
        /// <returns>A tensor equal to <c>-x</c>.</returns>
        public static Tensor<Type> operator -(Tensor<Type> x)
        {
            // NOTE(review): the two Elementwise rewrites only check the *type* of the abstraction's root.
            // They presumably assume a Neg abstraction is exactly -vars[0] and a Sub abstraction is exactly
            // vars[0] - vars[1]; if an Elementwise can carry e.g. -(exp(_x)), these rewrites would be wrong — confirm.
            if (x is Elementwise negated && negated.Abstraction is Scalars.Neg<Type>)
            {
                return negated.Inputs[0];                                   // -(-a) = a
            }

            if (x is Elementwise difference && difference.Abstraction is Scalars.Sub<Type>)
            {
                return difference.Inputs[1] - difference.Inputs[0];         // -(a - b) = b - a
            }

            if (x is Fill<Type> fill)
            {
                // Negate the fill value; the shape is kept as-is.
                return Op.Const(-fill.x, fill.Shape);
            }

            if (x is OneHot<Type> oneHot)
            {
                // Only the content is negated; shape and index are unchanged.
                return Op.OneHot(oneHot.Shape, oneHot.Index, -oneHot.Content);
            }

            // No simplification applies: negate element by element.
            return Op.Apply(x, _x => -_x);
        }