Example #1
0
        /// <summary>
        /// Creates session options from a fixed, pre-serialized configuration blob.
        /// </summary>
        /// <returns>The configured <see cref="TFSessionOptions"/>.</returns>
        /// <exception cref="InvalidOperationException">TensorFlow reported an error while applying the blob.</exception>
        public static TFSessionOptions CreateOptions()
        {
            TFSessionOptions options = new TFSessionOptions();

            // NOTE(review): presumably a serialized tensorflow.ConfigProto — decode against config.proto to
            // confirm which fields it sets. Per-device variants that used to sit here as commented-out code
            // were removed as dead code; recover them from version control if needed.
            byte[] serialized = { 0x32, 0x2, 0x20, 0x1, 0x38, 0x1 };

            using (TFStatus stat = new TFStatus())
            {
                unsafe
                {
                    // Pin the array so the native side receives a stable pointer for the duration of the call.
                    fixed (byte* serializedPtr = serialized)
                    {
                        options.SetConfig(new IntPtr(serializedPtr), serialized.Length, stat);
                    }
                }

                // A status object is supplied, so failures are presumably recorded in it rather than thrown;
                // check it instead of silently returning options that were never configured.
                if (stat.StatusCode != TFCode.Ok)
                {
                    throw new InvalidOperationException("Failed to apply session config: " + stat.StatusMessage);
                }
            }

            return options;
        }
Example #2
0
        /// <summary>
        /// Creates a session and loads its graph from a SavedModel on disk, using the given device.
        /// </summary>
        /// <returns>
        /// The new <see cref="TFSession"/> bound to <paramref name="graph"/> on success; <c>null</c> if loading
        /// failed and <paramref name="status"/> was supplied to receive the error.
        /// </returns>
        /// <param name="sessionOptions">Session options to use for the new session.</param>
        /// <param name="runOptions">Options to use to initialize the state (can be null).</param>
        /// <param name="exportDir">Must be set to the path of the exported SavedModel.</param>
        /// <param name="tags">Must include the set of tags used to identify one MetaGraphDef in the SavedModel.</param>
        /// <param name="graph">This must be a newly created graph; on success it holds the graph stored in the model.</param>
        /// <param name="device">
        /// Device name forwarded unchanged to the native TF_LoadSessionFromSavedModelOnDevice; presumably a
        /// TensorFlow device string such as "/device:GPU:0" (TODO confirm the accepted format and whether null is allowed).
        /// </param>
        /// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="sessionOptions"/>, <paramref name="graph"/>, <paramref name="tags"/> or <paramref name="exportDir"/> is null.
        /// </exception>
        /// <remarks>
        /// This function creates a new session using the specified <paramref name="sessionOptions"/> and then initializes
        /// the state (restoring tensors and other assets) using <paramref name="runOptions"/>
        /// </remarks>
        public static TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, string device, TFStatus status = null)
        {
            // sessionOptions.handle is dereferenced below; fail with a clear argument error instead of a NullReferenceException.
            if (sessionOptions == null)
            {
                throw new ArgumentNullException(nameof(sessionOptions));
            }
            if (graph == null)
            {
                throw new ArgumentNullException(nameof(graph));
            }
            if (tags == null)
            {
                throw new ArgumentNullException(nameof(tags));
            }
            if (exportDir == null)
            {
                throw new ArgumentNullException(nameof(exportDir));
            }
            var cstatus = TFStatus.Setup(status);

            unsafe
            {
                var h = TF_LoadSessionFromSavedModelOnDevice(sessionOptions.handle, runOptions == null ? null : runOptions.LLBuffer, exportDir, tags, tags.Length, graph.handle, device, cstatus.handle);

                // True when the load succeeded; otherwise it either raised (no caller status) or left the error in status.
                if (cstatus.CheckMaybeRaise(status))
                {
                    return(new TFSession(h, graph));
                }
            }
            return(null);
        }
Example #3
0
        public void WhileTest()
        {
            using (var j = new WhileTester()) {
                // Create loop: while (input1 < input2) input1 += input2 + 1
                j.Init(2, (TFGraph conditionGraph, TFOutput [] condInputs, out TFOutput condOutput, TFGraph bodyGraph, TFOutput [] bodyInputs, TFOutput [] bodyOutputs, out string name) => {
                    Assert(bodyGraph.Handle != IntPtr.Zero);
                    Assert(conditionGraph.Handle != IntPtr.Zero);

                    // Condition graph: input1 < input2
                    var lessThan = conditionGraph.Less(condInputs [0], condInputs [1]);
                    condOutput = new TFOutput(lessThan.Operation, 0);

                    // Body graph: input1 becomes input1 + input2 + 1; input2 passes through unchanged
                    var add1        = bodyGraph.Add(bodyInputs [0], bodyInputs [1]);
                    var one         = bodyGraph.Const(1);
                    var add2        = bodyGraph.Add(add1, one);
                    bodyOutputs [0] = new TFOutput(add2, 0);
                    bodyOutputs [1] = bodyInputs [1];

                    name = "Simple1";
                });

                // Starting from (-9, 2): -9 -> -6 -> -3 -> 0 -> 3; then 3 < 2 is false and the loop stops.
                var res = j.Run(-9, 2);

                Assert(3 == (int)res [0].GetValue());
                Assert(2 == (int)res [1].GetValue());
            }
        }
Example #4
0
 /// <summary>
 /// Throws when <paramref name="status"/> does not report <see cref="TFCode.Ok"/>.
 /// </summary>
 /// <param name="status">Status to inspect.</param>
 /// <param name="caller">Name of the calling member; supplied by the compiler via CallerMemberName.</param>
 /// <param name="message">Optional extra text appended to the exception message.</param>
 /// <exception cref="ArgumentNullException"><paramref name="status"/> is null.</exception>
 /// <exception cref="InvalidOperationException">The status code is not Ok.</exception>
 public static void Assert(TFStatus status, [CallerMemberName] string caller = null, string message = "")
 {
     if (status == null)
     {
         throw new ArgumentNullException(nameof(status));
     }
     if (status.StatusCode != TFCode.Ok)
     {
         throw new InvalidOperationException($"{caller}: {status.StatusMessage} {message}");
     }
 }
Example #5
0
        //
        // Shows the use of Variable
        //
        void TestVariable()
        {
            Console.WriteLine("Variables");

            using (var status = new TFStatus())
            using (var g = new TFGraph()) {
                var         initValue = g.Const(1.5);
                var         increment = g.Const(0.5);
                TFOperation init;
                TFOutput    value;
                var         handle = g.Variable(initValue, out init, out value);

                // Add 0.5 and assign to the variable.
                // Perhaps using op.AssignAddVariable would be better,
                // but demonstrating with Add and Assign for now.
                var update = g.AssignVariableOp(handle, g.Add(value, increment));

                using (var s = new TFSession(g)) {
                    // Must first initialize all the variables.
                    s.GetRunner().AddTarget(init).Run(status);
                    Assert(status);

                    // Now print the value, run the update op and repeat.
                    // Run() is called without a status object, so errors are not ignored here; presumably
                    // they are raised as exceptions (TODO confirm Runner.Run's null-status behavior).
                    // NOTE(review): the read and the update share one run with no control dependency, so
                    // whether the printed value is taken before or after the update is not guaranteed.
                    for (int i = 0; i < 5; i++)
                    {
                        var result = s.GetRunner().Fetch(value).AddTarget(update).Run();

                        Console.WriteLine("Result of variable read {0} -> {1}", i, result [0].GetValue());
                    }
                }
            }
        }
Example #6
0
    /// <summary>
    /// Computes the singular value decomposition of <paramref name="covMat"/> with a one-off TensorFlow graph.
    /// </summary>
    /// <param name="covMat">Input matrix, fed as a Float placeholder of the same 2-D shape.</param>
    /// <param name="s">Receives the singular values (the Svd op's "s" output).</param>
    /// <param name="v">Receives the right singular vectors (the Svd op's "v" output).</param>
    public static void SVD(float[,] covMat, out float[] s, out float[,] v)
    {
        TFShape shape       = new TFShape(covMat.GetLength(0), covMat.GetLength(1));
        var     reshaped    = covMat.Reshape();
        var     inputTensor = TFTensor.FromBuffer(shape, reshaped, 0, reshaped.Length);

        using (TFGraph svdGraph = new TFGraph())
        {
            TFOutput input = svdGraph.Placeholder(TFDataType.Float, shape);

            // The Svd op's outputs are ordered (s, u, v); with compute_uv = true all three are produced.
            // Item3 is v — fetching Item2 would return u instead.
            var svdResult = (ValueTuple <TFOutput, TFOutput, TFOutput>)svdGraph.Svd(input, true);

            var sess = new TFSession(svdGraph);
            using (TFStatus temp = new TFStatus())
            {
                try
                {
                    var runner = sess.GetRunner();

                    runner.AddInput(input, inputTensor);
                    runner.Fetch(svdResult.Item1);
                    runner.Fetch(svdResult.Item3);

                    TFTensor[] results = runner.Run();
                    s = (float[])results[0].GetValue();
                    v = (float[, ])results[1].GetValue();
                }
                finally
                {
                    // Release the native session even when Run throws.
                    sess.CloseSession(temp);
                    sess.DeleteSession(temp);
                }
            }
        }
    }
        public void GetAttributesTest()
        {
            Console.WriteLine("Testing attribute getting");
            var status = new TFStatus();

            using (var graph = new TFGraph()) {
                // Create a graph holding one float Placeholder named "node" with shape [1, 2, 3]
                var desc = new TFOperationDesc(graph, "Placeholder", "node");
                desc.SetAttrType("dtype", TFDataType.Float);
                long [] refShape = { 1, 2, 3 };
                desc.SetAttrShape("shape", new TFShape(refShape));
                desc.FinishOperation();

                // Look the operation up by name, which is part of what this test exercises
                var node = graph ["node"];
                Assert(node != null);

                // Check that the type is correct
                Assert(node.GetAttributeType("dtype", status) == TFDataType.Float);
                Assert(status);

                // Check that the shape is correct; verify the query succeeded before comparing its result
                var metadata = node.GetAttributeMetadata("shape");
                var shape    = node.GetAttributeShape("shape", (int)metadata.TotalSize, status);
                Assert(status);
                Assert(Enumerable.SequenceEqual(shape.ToArray(), refShape));
            }
        }
Example #8
0
        /// <summary>
        /// Appends an "AddN" operation named "add" that sums output 0 of <paramref name="left"/> and <paramref name="right"/>.
        /// </summary>
        /// <param name="left">First operand; its output 0 is used.</param>
        /// <param name="right">Second operand; its output 0 is used.</param>
        /// <param name="graph">Graph that receives the new operation.</param>
        /// <param name="status">Not used by this method.</param>
        /// <returns>The finished AddN operation.</returns>
        TFOperation Add(TFOperation left, TFOperation right, TFGraph graph, TFStatus status)
        {
            var addN = new TFOperationDesc(graph, "AddN", "add");

            TFOutput[] operands = { new TFOutput(left, 0), new TFOutput(right, 0) };
            addN.AddInputs(operands);

            var finished = addN.FinishOperation();
            return finished;
        }
Example #9
0
        static void Main(string[] args)
        {
            // Create the graph
            using (var g = new TFGraph())
            // Status object that receives the outcome of the initialization run
            using (TFStatus status = new TFStatus())
            {
                // Define the variables: a is a double scalar initialized to 1.0;
                // b is an Int32 vector of shape [99] initialized to Range(1, 100)
                var a     = g.VariableV2(TFShape.Scalar, TFDataType.Double);
                var initA = g.Assign(a, g.Const(1.0));

                var b     = g.VariableV2(new TFShape(99), TFDataType.Int32);
                var initB = g.Assign(b, g.Range(g.Const(1), g.Const(100)));

                // Create the session
                using (var sess = new TFSession(g))
                {
                    // Initialize the variables and print the resulting status code
                    sess.GetRunner().AddTarget(initA.Operation, initB.Operation).Run(status);
                    Console.WriteLine(status.StatusCode);

                    // Fetch and print the variable values. A fresh runner is used because a runner
                    // presumably keeps its earlier targets, which would re-run the initializers here.
                    var result = sess.GetRunner().Fetch(a, b).Run();

                    Console.WriteLine(result[0].GetValue());
                    Console.WriteLine(string.Join(",", (int[])result[1].GetValue()));
                }
            }
        }
Example #10
0
        /// <summary>
        /// Creates a scalar TF_STRING tensor whose single element holds the bytes of <paramref name="buffer"/>.
        /// </summary>
        /// <param name="buffer">Raw bytes to store; may be empty.</param>
        /// <returns>The new tensor, or <c>null</c> if TF_StringEncode reports an error.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        internal static unsafe TFTensor CreateString(byte[] buffer)
        {
            if (buffer == null)
            {
                throw new ArgumentNullException(nameof(buffer));
            }
            //
            // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
            // TF_StringEncode-encoded bytes.
            //
            var    size   = TF_StringEncodedSize((UIntPtr)buffer.Length);
            IntPtr handle = TF_AllocateTensor(TFDataType.String, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));

            // Zero dims => one element, so the offset table is a single 8-byte entry pointing at offset 0.
            IntPtr dst = TF_TensorData(handle);

            Marshal.WriteInt64(dst, 0);

            bool ok;
            using (var status = new TFStatus())
            {
                // Pin the array itself rather than &buffer[0]: for an empty buffer that yields a null pointer
                // instead of throwing IndexOutOfRangeException (with length 0 nothing should be read through it).
                fixed (byte* src = buffer)
                {
                    TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(dst + 8), size, status.handle);
                }
                ok = status.StatusCode == TFCode.Ok;
            }

            var tensor = new TFTensor(handle);
            if (!ok)
            {
                // Free the native allocation instead of leaking it on the failure path.
                tensor.Dispose();
                return(null);
            }
            return(tensor);
        }
Example #11
0
        /// <summary>
        /// Appends an Int32 Placeholder operation named "feed", printing the description and finished handles.
        /// </summary>
        /// <param name="graph">Graph that receives the operation.</param>
        /// <param name="s">Not used by this method.</param>
        /// <returns>The finished placeholder operation.</returns>
        TFOperation Placeholder(TFGraph graph, TFStatus s)
        {
            var placeholderDesc = new TFOperationDesc(graph, "Placeholder", "feed");
            placeholderDesc.SetAttrType("dtype", TFDataType.Int32);
            Console.WriteLine("Handle: {0}", placeholderDesc.Handle);

            TFOperation finished = placeholderDesc.FinishOperation();
            Console.WriteLine("FinishHandle: {0}", finished.Handle);

            return finished;
        }
Example #12
0
        public void TestSession()
        {
            var status = new TFStatus();

            using (var graph = new TFGraph()) {
                // Build feed + 2
                var feed = Placeholder(graph, status);
                var two  = ScalarConst(2, graph, status);
                var add  = Add(feed, two, graph, status);
                Assert(status);

                // Create a session for this graph
                using (var session = new TFSession(graph, status)) {
                    Assert(status);

                    var feedOutput = new TFOutput(feed, 0);
                    var addOutput  = new TFOutput(add, 0);

                    // Run the graph through the low-level Session.Run API, feeding 3
                    var inputs = new TFOutput [] {
                        feedOutput
                    };
                    var inputValues = new TFTensor [] {
                        3
                    };
                    var outputs = new TFOutput [] {
                        addOutput
                    };

                    var results = session.Run(runOptions: null,
                                              inputs: inputs,
                                              inputValues: inputValues,
                                              outputs: outputs,
                                              targetOpers: null,
                                              runMetadata: null,
                                              status: status);
                    Assert(status);
                    var res = results [0];
                    Assert(res.TensorType == TFDataType.Int32);
                    Assert(res.NumDims == 0);                      // Scalar
                    Assert(res.TensorByteSize == (UIntPtr)4);
                    Assert(Marshal.ReadInt32(res.Data) == 3 + 2);

                    // Use runner API
                    var runner = session.GetRunner();
                    runner.AddInput(feedOutput, 3);
                    runner.Fetch(addOutput);
                    results = runner.Run(status: status);
                    // Check the run succeeded before touching its results
                    Assert(status);
                    res = results [0];
                    Assert(res.TensorType == TFDataType.Int32);
                    Assert(res.NumDims == 0);                      // Scalar
                    Assert(res.TensorByteSize == (UIntPtr)4);
                    Assert(Marshal.ReadInt32(res.Data) == 3 + 2);
                }
            }
        }
Example #13
0
        /// <summary>
        /// Appends a "Const" operation whose value is <paramref name="v"/>.
        /// </summary>
        /// <param name="v">Constant value; its tensor type is used as the op's dtype.</param>
        /// <param name="graph">Graph that receives the operation.</param>
        /// <param name="status">Receives the outcome of setting the value attribute.</param>
        /// <param name="name">Operation name; defaults to "scalar".</param>
        /// <returns>The finished operation, or <c>null</c> if setting the value attribute failed.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="v"/> is null.</exception>
        TFOperation ScalarConst(TFTensor v, TFGraph graph, TFStatus status, string name = null)
        {
            if (v == null)
            {
                throw new ArgumentNullException(nameof(v));
            }

            var desc = new TFOperationDesc(graph, "Const", name ?? "scalar");

            desc.SetAttr("value", v, status);
            if (status.StatusCode != TFCode.Ok)
            {
                return(null);
            }
            // dtype must match the value tensor; taking it from the tensor keeps Int32 behavior
            // unchanged while also supporting other element types.
            desc.SetAttrType("dtype", v.TensorType);
            return(desc.FinishOperation());
        }
Example #14
0
        /// <summary>
        /// Appends an Int32 "Const" operation named "scalar" holding <paramref name="v"/>.
        /// </summary>
        /// <param name="v">Constant value.</param>
        /// <param name="graph">Graph that receives the operation.</param>
        /// <param name="status">Receives the outcome of setting the value attribute.</param>
        /// <returns>The finished operation, or <c>null</c> if setting the value attribute failed.</returns>
        TFOperation ScalarConst(int v, TFGraph graph, TFStatus status)
        {
            var constDesc = new TFOperationDesc(graph, "Const", "scalar");
            constDesc.SetAttr("value", v, status);

            bool valueAccepted = status.StatusCode == TFCode.Ok;
            if (!valueAccepted)
            {
                return null;
            }

            constDesc.SetAttrType("dtype", TFDataType.Int32);
            TFOperation finished = constDesc.FinishOperation();
            return finished;
        }
Example #15
0
        public void AddControlInput()
        {
            Console.WriteLine("Testing AddControlInput for assertions");
            var status = new TFStatus();

            using (var g = new TFGraph())
            using (var s = new TFSession(g, status))
            {
                TFTensor yes         = true;
                TFTensor no          = false;
                var      placeholder = g.Placeholder(TFDataType.Bool, operName: "boolean");

                // Assert op: condition = placeholder, data list = [placeholder]
                var check = new TFOperationDesc(g, "Assert", "assert")
                            .AddInput(placeholder)
                            .AddInputs(placeholder)
                            .FinishOperation();

                // NoOp with a control dependency on the Assert, so running it forces the check
                var noop = new TFOperationDesc(g, "NoOp", "noop")
                           .AddControlInput(check)
                           .FinishOperation();

                var runner = s.GetRunner();
                runner.AddInput(placeholder, yes);
                runner.AddTarget(noop);

                // No problems when the Assert check succeeds
                runner.Run();

                // Exception thrown by the execution of the Assert node. The "should have thrown" check
                // lives outside the try: inside it, the catch would swallow it and the test could never fail.
                bool threw = false;
                try
                {
                    runner = s.GetRunner();
                    runner.AddInput(placeholder, no);
                    runner.AddTarget(noop);
                    runner.Run();
                }
                catch (Exception)
                {
                    threw = true;
                }

                if (!threw)
                {
                    throw new Exception("This should have thrown an exception");
                }
                Console.WriteLine("Success, got the expected exception when using tensorflow control inputs to assert");
            }
        }
Example #16
0
        public void TestImportGraphDef()
        {
            var status = new TFStatus();

            // graphDef is released on every path, including when an assertion throws
            using (var graphDef = new TFBuffer())
            {
                // Create graph with two nodes, "feed" and "scalar", and export it to a GraphDef
                using (var graph = new TFGraph())
                {
                    Placeholder(graph, status);
                    Assert(graph["feed"] != null);

                    ScalarConst(3, graph, status);
                    Assert(graph["scalar"] != null);

                    graph.ToGraphDef(graphDef, status);
                    Assert(status);
                }

                // Import it again, with a prefix, in a fresh graph
                using (var graph = new TFGraph())
                {
                    using (var options = new TFImportGraphDefOptions())
                    {
                        options.SetPrefix("imported");
                        graph.Import(graphDef, options, status);
                        Assert(status);
                    }

                    var scalar = graph["imported/scalar"];
                    var feed   = graph["imported/feed"];
                    Assert(scalar != null);
                    Assert(feed != null);

                    // Can add nodes to the imported graph without trouble
                    Add(feed, scalar, graph, status);
                    Assert(status);
                }
            }
        }
Example #17
0
        public void Variables()
        {
            using (TFGraph g = new TFGraph())
            using (TFSession s = new TFSession(g))
            using (TFStatus status = new TFStatus())
            {
                TFOutput vW, vb, vlinmodel;
                g.Variable(g.Const(0.3F, TFDataType.Float), out vW);
                g.Variable(g.Const(-0.3F, TFDataType.Float), out vb);
                var hlinearmodel = g.Variable(g.Const(0.0F, TFDataType.Float), out vlinmodel);
                var x            = g.Placeholder(TFDataType.Float);

                // linmodel = W * x + b
                var hoplm = g.AssignVariableOp(hlinearmodel, g.Add(g.Mul(vW, x), vb));

                // Initialize all variables in their own run: targets of one run have no control dependency
                // between them, so hoplm could otherwise read W or b before they are initialized.
                s.GetRunner()
                 .AddTarget(g.GetGlobalVariablesInitializer())
                 .Run(status);
                Assert.True(status.Ok, status.StatusMessage);

                s.GetRunner()
                 .AddTarget(hoplm)
                 .AddInput(x, new float[] { 1F, 2F, 3F, 4F })
                 .Run(status);
                Assert.True(status.Ok, status.StatusMessage);

                //now get actual value
                var result = s.GetRunner()
                             .Fetch(vlinmodel)
                             .Run();

                Assert.NotNull(result);
                Assert.Equal(1, result.Length);
                Assert.IsType <float[]>(result[0].GetValue());

                float[] values = (float[])result[0].GetValue();
                Assert.NotNull(values);
                Assert.Equal(4, values.Length);
                Assert.Equal(0.0F, values[0], 7);
                Assert.Equal(0.3F, values[1], 7);
                Assert.Equal(0.6F, values[2], 7);
                Assert.Equal(0.9F, values[3], 7);
            }
        }
Example #18
0
        // Initialization: creates two variables in g, initializes them through sess, then prints their values.
        private static void InitVariable(TFGraph g, TFSession sess)
        {
            Console.WriteLine("。。。。。。。。。。。。。。。。。初始化。。。。。。。。。。。。。。。。。。");

            // a: double scalar initialized to 1.5
            var a     = g.VariableV2(TFShape.Scalar, TFDataType.Double);
            var initA = g.Assign(a, g.Const(1.5));

            // b: Int32 vector of shape [99] initialized to Range(1, 100)
            var b     = g.VariableV2(new TFShape(99), TFDataType.Int32);
            var initB = g.Assign(b, g.Range(g.Const(1), g.Const(100)));

            using (TFStatus status = new TFStatus())
            {
                sess.GetRunner().AddTarget(initA.Operation, initB.Operation).Run(status);
                Console.WriteLine(status.StatusCode);
            }

            // Fetch with a fresh runner; a runner presumably keeps its earlier targets,
            // which would re-run the initializers alongside the reads.
            var res = sess.GetRunner().Fetch(a, b).Run();

            Console.WriteLine(res[0].GetValue());
            Console.WriteLine(string.Join(",", (int[])res[1].GetValue()));
        }
        public void Variables()
        {
            // NOTE(review): g and s are not declared here; presumably graph/session members of the test class.
            using (TFStatus status = new TFStatus())
            {
                TFOutput vW, vb, vlinmodel;
                g.Variable(g.Const(0.3F, TFDataType.Float), out vW);
                g.Variable(g.Const(-0.3F, TFDataType.Float), out vb);
                var hlinearmodel = g.Variable(g.Const(0.0F, TFDataType.Float), out vlinmodel);
                var x            = g.Placeholder(TFDataType.Float);

                // linmodel = W * x + b
                var hoplm = g.AssignVariableOp(hlinearmodel, g.Add(g.Mul(vW, x), vb));

                // Initialize all variables in their own run: targets of one run have no control dependency
                // between them, so hoplm could otherwise read W or b before they are initialized.
                s.GetRunner()
                 .AddTarget(g.GetGlobalVariablesInitializer())
                 .Run(status);
                Assert.IsTrue(status.Ok, status.StatusMessage);

                s.GetRunner()
                 .AddTarget(hoplm)
                 .AddInput(x, new float[] { 1F, 2F, 3F, 4F })
                 .Run(status);
                Assert.IsTrue(status.Ok, status.StatusMessage);

                //now get actual value
                var result = s.GetRunner()
                             .Fetch(vlinmodel)
                             .Run();

                // MSTest expects (expected, actual) so failure messages read correctly
                Assert.IsNotNull(result);
                Assert.AreEqual(1, result.Length);
                Assert.IsInstanceOfType(result[0].GetValue(), typeof(float[]));

                float[] values = (float[])result[0].GetValue();
                Assert.IsNotNull(values);
                Assert.AreEqual(4, values.Length);
                Assert.AreEqual(0.0F, values[0], 0.0000001F);
                Assert.AreEqual(0.3F, values[1], 0.0000001F);
                Assert.AreEqual(0.6F, values[2], 0.0000001F);
                Assert.AreEqual(0.9F, values[3], 0.0000001F);
            }
        }
Example #20
0
        public void TestParametersWithIndexes()
        {
            Console.WriteLine("Testing Parameters with indexes");
            var status = new TFStatus();

            using (var g = new TFGraph())
            using (var s = new TFSession(g, status))
            {
                // Split [1, 2, 3, 4] along axis 0 into two halves: Split:0 = [1, 2], Split:1 = [3, 4]
                var split = new TFOperationDesc(g, "Split", "Split")
                            .AddInput(ScalarConst(0, g, status)[0])
                            .AddInput(ScalarConst(new TFTensor(new int[] { 1, 2, 3, 4 }), g, status, "array")[0])
                            .SetAttr("num_split", 2)
                            .FinishOperation();

                // Add = Split:0 + Split:1 (only referenced by name below)
                new TFOperationDesc(g, "Add", "Add")
                .AddInput(split[0]).AddInput(split[1]).FinishOperation();

                // fetch using colon separated names
                var fetched = s.GetRunner().Fetch("Split:1").Run()[0];
                var vals    = (int[])fetched.GetValue();
                if (vals[0] != 3 || vals[1] != 4)
                {
                    throw new Exception("Expected the values 3 and 4");
                }

                // Add inputs using colon separated names.
                var t   = new TFTensor(new int[] { 4, 3, 2, 1 });
                var ret = s.GetRunner().AddInput("Split:0", t).AddInput("Split:1", t).Fetch("Add").Run()[0];
                var val = (int[])ret.GetValue();

                if (val[0] != 8 || val[1] != 6 || val[2] != 4 || val[3] != 2)
                {
                    throw new Exception("Expected 8, 6, 4, 2");
                }
            }
            Console.WriteLine("success");
        }
Example #21
0
        /// <summary>
        /// Describes a graph operation as a <see cref="Tensor"/> (name, shape, value type, no data) by
        /// querying the operation's "shape" and "dtype" attributes.
        /// </summary>
        /// <param name="op">Operation to inspect (presumably a Placeholder — TODO confirm with callers).</param>
        /// <returns>
        /// A tensor with <c>Data == null</c>. Shape is null when the shape attribute is missing or empty;
        /// unknown dimensions (-1) are replaced with 1.
        /// </returns>
        /// <exception cref="NotImplementedException">The shape attribute is a list.</exception>
        /// <exception cref="FormatException">The shape attribute could not be read.</exception>
        private Tensor GetOpMetadata(TFOperation op)
        {
            long[] shape = null;

            // Query the shape
            using (TFStatus status = new TFStatus())
            {
                var shapeAttr = op.GetAttributeMetadata("shape", status);

                if (!status.Ok || shapeAttr.TotalSize <= 0)
                {
                    Debug.LogWarning("Operation " + op.Name + " does not contain shape attribute or it" +
                                     " doesn't contain valid shape data!");
                }
                else if (shapeAttr.IsList)
                {
                    throw new NotImplementedException("Querying lists is not implemented yet!");
                }
                else
                {
                    long[] dims = new long[shapeAttr.TotalSize];
                    using (TFStatus shapeStatus = new TFStatus())
                    {
                        TF_OperationGetAttrShape(op.Handle, "shape", dims, (int)shapeAttr.TotalSize,
                                                 shapeStatus.Handle);
                        // Check this call's own status; the outer status was already known to be Ok.
                        if (!shapeStatus.Ok)
                        {
                            throw new FormatException("Could not query model for op shape (" + op.Name + ")");
                        }
                    }

                    shape = new long[dims.Length];
                    for (int i = 0; i < dims.Length; ++i)
                    {
                        // Unknown dimensions (-1): we have to use batchsize 1
                        shape[i] = dims[i] == -1 ? 1 : dims[i];
                    }
                }
            }

            // Query the data type, falling back to float if it cannot be read
            TFDataType type_value = new TFDataType();

            using (TFStatus dtypeStatus = new TFStatus())
            {
                unsafe
                {
                    TF_OperationGetAttrType(op.Handle, "dtype", &type_value, dtypeStatus.Handle);
                }
                if (!dtypeStatus.Ok)
                {
                    Debug.LogWarning("Operation " + op.Name +
                                     ": error retrieving dtype, assuming float!");
                    type_value = TFDataType.Float;
                }
            }

            Tensor.TensorType placeholder_type = Tensor.TensorType.FloatingPoint;
            switch (type_value)
            {
            case TFDataType.Float:
                placeholder_type = Tensor.TensorType.FloatingPoint;
                break;

            case TFDataType.Int32:
                placeholder_type = Tensor.TensorType.Integer;
                break;

            default:
                Debug.LogWarning("Operation " + op.Name +
                                 " is not a float/integer. Proceed at your own risk!");
                break;
            }

            Tensor t = new Tensor
            {
                Data      = null,
                Name      = op.Name,
                Shape     = shape,
                ValueType = placeholder_type
            };

            return(t);
        }
Example #22
0
        /// <summary>
        /// Feeds <paramref name="inputs_it"/> into the session, fetches the operations named by
        /// <paramref name="outputs_it"/>, and stores the fetched values in each output's Data.
        /// </summary>
        /// <param name="inputs_it">Tensors to feed; output 0 of the operation m_graph[Name] receives each one.</param>
        /// <param name="outputs_it">Tensors to fetch by Name; their Data arrays are replaced with the results.</param>
        /// <returns>0 on success; -1 if the run reported an error (the message is logged).</returns>
        public int ExecuteGraph(IEnumerable <Tensor> inputs_it, IEnumerable <Tensor> outputs_it)
        {
            Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph");
            // The finally below closes this sample on every exit, including the error return.
            try
            {
                Tensor[] inputs  = inputs_it.ToArray();
                Tensor[] outputs = outputs_it.ToArray();

                // TODO: Can/should we pre-allocate that?
                TFSession.Runner runner = m_session.GetRunner();

                foreach (Tensor input in inputs)
                {
                    if (input.Shape.Length == 0)
                    {
                        // Scalars are fed as a single value rather than as an array
                        var data = input.Data.GetValue(0);
                        if (input.DataType == typeof(int))
                        {
                            runner.AddInput(m_graph[input.Name][0], (int)data);
                        }
                        else
                        {
                            runner.AddInput(m_graph[input.Name][0], (float)data);
                        }
                    }
                    else
                    {
                        runner.AddInput(m_graph[input.Name][0], input.Data);
                    }
                }

                // TODO: better way to pre-allocate this?
                foreach (Tensor output in outputs)
                {
                    runner.Fetch(output.Name);
                }

                using (TFStatus status = new TFStatus())
                {
                    Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph.RunnerRun");
                    var out_tensors = runner.Run(status);

                    Profiler.EndSample();

                    if (!status.Ok)
                    {
                        Debug.LogError(status.StatusMessage);
                        return(-1);
                    }

                    Debug.Assert(outputs.Length == out_tensors.Length);

                    try
                    {
                        for (var i = 0; i < outputs.Length; ++i)
                        {
                            if (outputs[i].Shape.Length == 0)
                            {
                                // Handle scalars
                                outputs[i].Data = Array.CreateInstance(outputs[i].DataType, new long[1] {
                                    1
                                });
                                outputs[i].Data.SetValue(out_tensors[i].GetValue(), 0);
                            }
                            else
                            {
                                outputs[i].Data = out_tensors[i].GetValue() as Array;
                            }
                        }
                    }
                    finally
                    {
                        // GetValue presumably copies into managed memory (TODO confirm), so the native
                        // tensors can be released now instead of waiting for finalization on every call.
                        foreach (var tensor in out_tensors)
                        {
                            tensor?.Dispose();
                        }
                    }
                }

                // TODO: create error codes
                return(0);
            }
            finally
            {
                Profiler.EndSample();
            }
        }
Example #23
0
 /// <summary>
 /// Initializes the Status and graph members with fresh instances.
 /// </summary>
 public AttributeTest()
 {
     var freshStatus = new TFStatus();
     Status = freshStatus;

     var freshGraph = new TFGraph();
     graph = freshGraph;
 }
Example #24
0
 /// <summary>
 /// Initializes the status and graph members with fresh instances.
 /// </summary>
 public WhileTester()
 {
     var freshStatus = new TFStatus();
     status = freshStatus;

     var freshGraph = new TFGraph();
     graph = freshGraph;
 }