Ejemplo n.º 1
0
        /// <summary>
        /// Runs the TensorFlow graph once. Feeds every tensor in <paramref name="inputs_it"/>,
        /// fetches every tensor named in <paramref name="outputs_it"/>, and writes each fetched
        /// result into the <c>Data</c> of the output proxy at the same position.
        /// </summary>
        /// <param name="inputs_it">
        /// Input proxies. <c>Name</c> selects the graph operation whose first output is fed.
        /// A zero-length <c>Shape</c> marks a scalar, in which case only <c>Data[0]</c> is fed.
        /// Inputs whose <c>DataType</c> is <c>int</c> are fed as int; all others are fed as float.
        /// </param>
        /// <param name="outputs_it">
        /// Output proxies, fetched by <c>Name</c>. Their <c>Data</c> is replaced with the results.
        /// A zero-length <c>Shape</c> yields a 1x1 tensor holding the scalar value.
        /// </param>
        /// <returns>0 on success, or -1 when the session run reports a non-OK status (the message is logged).</returns>
        public int ExecuteGraph(IEnumerable <TensorProxy> inputs_it, IEnumerable <TensorProxy> outputs_it)
        {
            Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph");
            // Native result tensors returned by the run. They are released in the finally block
            // below, after GetValue() has copied their contents into managed arrays.
            TFTensor[] out_tensors = null;
            try
            {
                TensorProxy[] inputs  = inputs_it.ToArray();
                TensorProxy[] outputs = outputs_it.ToArray();

                // TODO: Can/should we pre-allocate that?
                TFSession.Runner runner = m_session.GetRunner();

                // NOTE(review): the TFTensors implicitly created for these feeds are never disposed.
                // Confirm that TF_SessionRun leaves input ownership with the caller before releasing them.
                foreach (TensorProxy input in inputs)
                {
                    var feedTarget = m_graph[input.Name][0];
                    if (input.Shape.Length == 0)
                    {
                        // Scalar input: only the first element of Data is fed.
                        var data = input.Data[0];
                        if (input.DataType == typeof(int))
                        {
                            runner.AddInput(feedTarget, (int)data);
                        }
                        else
                        {
                            runner.AddInput(feedTarget, (float)data);
                        }
                    }
                    else
                    {
                        runner.AddInput(feedTarget, input.DataType == typeof(int) ?
                                        TensorUtils.BarracudaToIntArray(input.Data) :
                                        TensorUtils.BarracudaToFloatArray(input.Data));
                    }
                }

                // TODO: better way to pre-allocate this?
                foreach (TensorProxy output in outputs)
                {
                    runner.Fetch(output.Name);
                }

                using (TFStatus status = new TFStatus())
                {
                    Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph.RunnerRun");
                    try
                    {
                        out_tensors = runner.Run(status);
                    }
                    finally
                    {
                        // Close the inner sample even if Run throws.
                        Profiler.EndSample();
                    }

                    if (!status.Ok)
                    {
                        Debug.LogError(status.StatusMessage);
                        // TODO: create error codes
                        return -1;
                    }
                }

                Debug.Assert(outputs.Length == out_tensors.Length);

                for (var i = 0; i < outputs.Length; ++i)
                {
                    object value = out_tensors[i].GetValue();
                    if (outputs[i].Shape.Length == 0)
                    {
                        // Scalar output: the boxed value may be an int or a float, so convert it
                        // rather than unboxing to a fixed type (unboxing a float as int throws).
                        outputs[i].Data    = new Tensor(1, 1);
                        outputs[i].Data[0] = Convert.ToSingle(value);
                    }
                    else
                    {
                        outputs[i].Data = TensorUtils.ArrayToBarracuda(value as Array);
                    }
                }

                return 0;
            }
            finally
            {
                if (out_tensors != null)
                {
                    foreach (TFTensor tensor in out_tensors)
                    {
                        if (tensor != null)
                        {
                            tensor.Dispose();
                        }
                    }
                }
                // Always balance the outer BeginSample, including on the error return and on exceptions.
                Profiler.EndSample();
            }
        }