예제 #1
0
        /// <summary>
        /// Looks up the shape of the tensor called <paramref name="name"/> in <paramref name="model"/>.
        /// </summary>
        /// <remarks>
        /// Sources are searched in this order and the first match wins:
        /// model inputs, the shape reported by <c>ModelAnalyzer.TryGetOutputTensorShape</c>,
        /// the datasets of every layer, and finally memory inputs/outputs.
        /// </remarks>
        /// <param name="model">Model to search.</param>
        /// <param name="name">Name of the tensor whose shape is wanted.</param>
        /// <returns>The shape of the named tensor.</returns>
        /// <exception cref="System.Collections.Generic.KeyNotFoundException">
        /// None of the searched sources knows a tensor with this name.
        /// </exception>
        public static TensorShape GetShapeByName(this Model model, string name)
        {
            // 1. Declared model inputs: build the shape from the input's dimension list.
            foreach (var input in model.inputs)
            {
                if (input.name == name)
                {
                    return new TensorShape(input.shape);
                }
            }

            // 2. Shapes the analyzer can resolve for this name.
            TensorShape shape;
            if (ModelAnalyzer.TryGetOutputTensorShape(model, name, out shape))
            {
                return shape;
            }

            // 3. Datasets attached to any layer.
            foreach (var layer in model.layers)
            {
                foreach (var dataset in layer.datasets)
                {
                    if (dataset.name == name)
                    {
                        return dataset.shape;
                    }
                }
            }

            // 4. Memories match on either their input or their output name.
            foreach (var memory in model.memories)
            {
                if (memory.input == name || memory.output == name)
                {
                    return memory.shape;
                }
            }

            throw new System.Collections.Generic.KeyNotFoundException($"Shape {name} not found!");
        }
예제 #2
0
        /// <summary>
        /// Loads the inspected <c>NNModel</c> and precomputes the name/description lists the
        /// inspector displays: inputs, outputs, memories, layers, constants, weight counts and warnings.
        /// </summary>
        /// <remarks>
        /// Returns early, leaving the cached fields untouched, when the target is not an
        /// <c>NNModel</c>, has no model data, or fails to load.
        /// </remarks>
        void OnEnable()
        {
            // TODO: investigate perf -- method takes 1s the first time you click on the model in the UI
            var nnModel = target as NNModel;

            if (nnModel == null)
            {
                return;
            }
            if (nnModel.modelData == null)
            {
                return;
            }

            m_Model = ModelLoader.Load(nnModel, verbose: false);
            if (m_Model == null)
            {
                return;
            }

            m_Inputs     = m_Model.inputs.Select(i => i.name).ToList();
            m_InputsDesc = m_Model.inputs.Select(i => $"shape: ({String.Join(",", i.shape)})").ToList();
            m_Outputs    = m_Model.outputs.ToList();

            // Output shapes are only inferred when every input dimension is concrete;
            // an input containing -1 or 0 is treated as having an unknown shape.
            bool allKnownShapes = true;
            var  inputShapes    = new Dictionary <string, TensorShape>();

            foreach (var i in m_Model.inputs)
            {
                allKnownShapes = allKnownShapes && !i.shape.Contains(-1) && !i.shape.Contains(0);
                if (!allKnownShapes)
                {
                    break;
                }
                inputShapes.Add(i.name, new TensorShape(i.shape));
            }
            if (allKnownShapes)
            {
                m_OutputsDesc = m_Model.outputs.Select(i => {
                    string output = "(-1,-1,-1,-1)";
                    // Guard each output separately so one failing analysis is logged and shown
                    // as unknown instead of aborting the whole inspector initialization.
                    try
                    {
                        TensorShape shape;
                        if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, i, out shape))
                        {
                            output = shape.ToString();
                        }
                    }
                    catch (Exception e)
                    {
                        Debug.LogError($"Unexpected error while evaluating model output {i}. {e}");
                    }
                    return($"shape: {output}");
                }).ToList();
            }
            else
            {
                m_OutputsDesc = m_Model.outputs.Select(i => "shape: (-1,-1,-1,-1)").ToList();
            }

            m_Memories     = m_Model.memories.Select(i => i.input).ToList();
            m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

            // Split layers into Load layers ("constants") and everything else. Materialize both
            // once: each is enumerated three times below, and a bare Where() would re-filter
            // m_Model.layers on every pass.
            var layers    = m_Model.layers.Where(i => i.type != Layer.Type.Load).ToList();
            var constants = m_Model.layers.Where(i => i.type == Layer.Type.Load).ToList();

            m_Layers        = layers.Select(i => i.type.ToString()).ToList();
            m_LayersDesc    = layers.Select(i => i.ToString()).ToList();
            m_Constants     = constants.Select(i => i.type.ToString()).ToList();
            m_ConstantsDesc = constants.Select(i => i.ToString()).ToList();

            // Summed as long so very large models cannot overflow an int count.
            m_NumEmbeddedWeights = layers.Sum(l => (long)l.weights.Length).ToString();
            m_NumConstantWeights = constants.Sum(l => (long)l.weights.Length).ToString();

            m_Warnings     = m_Model.Warnings.Select(i => i.LayerName).ToList();
            m_WarningsDesc = m_Model.Warnings.Select(i => i.Message).ToList();
        }
        /// <summary>
        /// Loads the inspected <c>NNModel</c> and precomputes the name/description lists the
        /// inspector displays: inputs, outputs, memories, layers, total weight count and warnings.
        /// </summary>
        /// <remarks>
        /// Returns early, leaving the cached fields untouched, when the target is not an
        /// <c>NNModel</c>, has no model data, or fails to load.
        /// </remarks>
        void OnEnable()
        {
            // TODO: investigate perf -- method takes 1s the first time you click on the model in the UI
            var nnModel = target as NNModel;

            if (nnModel == null)
            {
                return;
            }
            if (nnModel.modelData == null)
            {
                return;
            }

            m_Model = ModelLoader.Load(nnModel, verbose: false);
            if (m_Model == null)
            {
                return;
            }

            m_Inputs     = m_Model.inputs.Select(i => i.name).ToList();
            m_InputsDesc = m_Model.inputs.Select(i => $"shape: ({String.Join(",", i.shape)})").ToList();
            m_Outputs    = m_Model.outputs.ToList();

            // Output shapes are only inferred when every input dimension is concrete;
            // an input containing -1 or 0 is treated as having an unknown shape.
            bool allKnownShapes = true;
            var  inputShapes    = new Dictionary <string, TensorShape>();

            foreach (var i in m_Model.inputs)
            {
                allKnownShapes = allKnownShapes && !i.shape.Contains(-1) && !i.shape.Contains(0);
                if (!allKnownShapes)
                {
                    break;
                }
                inputShapes.Add(i.name, new TensorShape(i.shape));
            }
            if (allKnownShapes)
            {
                m_OutputsDesc = m_Model.outputs.Select(i => {
                    string output = "(-1,-1,-1,-1)";
                    // Guard each output separately: without this, an exception thrown by the
                    // analyzer would escape OnEnable and leave m_OutputsDesc and every field
                    // assigned after it (memories, layers, weights, warnings) unset.
                    try
                    {
                        TensorShape shape;
                        bool success = ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, i, out shape);
                        if (success)
                        {
                            output = shape.ToString();
                        }
                    }
                    catch (Exception e)
                    {
                        Debug.LogError($"Unexpected error while evaluating model output {i}. {e}");
                    }
                    return($"shape: {output}");
                }).ToList();
            }
            else
            {
                m_OutputsDesc = m_Model.outputs.Select(i => "shape: (-1,-1,-1,-1)").ToList();
            }

            m_Memories     = m_Model.memories.Select(i => i.input).ToList();
            m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

            m_Layers     = m_Model.layers.Select(i => i.type.ToString()).ToList();
            m_LayersDesc = m_Model.layers.Select(i => i.ToString()).ToList();
            // Summed as long so very large models cannot overflow an int count.
            m_NumWeights = m_Model.layers.Sum(l => (long)l.weights.Length).ToString();

            m_Warnings     = m_Model.Warnings.Select(i => i.LayerName).ToList();
            m_WarningsDesc = m_Model.Warnings.Select(i => i.Message).ToList();
        }