Esempio n. 1
0
        /// <summary>
        /// Looks up the shape of the tensor called <paramref name="name"/> in <paramref name="model"/>.
        /// Candidates are checked in this order and the first match wins: model inputs,
        /// the result of <c>ModelAnalyzer.TryGetOutputTensorShape</c>, layer datasets,
        /// and finally memories (matched on either their input or their output name).
        /// </summary>
        /// <param name="model">Model to search</param>
        /// <param name="name">Name of the tensor to look up</param>
        /// <returns>Shape of the first tensor matching <paramref name="name"/></returns>
        /// <exception cref="System.Collections.Generic.KeyNotFoundException">Thrown when none of the lookups finds <paramref name="name"/></exception>
        public static TensorShape? GetShapeByName(this Model model, string name)
        {
            // 1. Model inputs: the shape is built from the input's raw shape data.
            foreach (var input in model.inputs)
            {
                if (input.name == name)
                {
                    return new TensorShape(input.shape);
                }
            }

            // 2. Let the analyzer try to resolve the name.
            TensorShape analyzedShape;
            bool resolvedByAnalyzer = ModelAnalyzer.TryGetOutputTensorShape(model, name, out analyzedShape);

            if (resolvedByAnalyzer)
            {
                return analyzedShape;
            }

            // 3. Datasets attached to layers.
            for (var layerIndex = 0; layerIndex < model.layers.Count; ++layerIndex)
            {
                var datasets = model.layers[layerIndex].datasets;

                for (var datasetIndex = 0; datasetIndex < datasets.Length; ++datasetIndex)
                {
                    if (datasets[datasetIndex].name == name)
                    {
                        return datasets[datasetIndex].shape;
                    }
                }
            }

            // 4. Memories: either end of the memory may carry the requested name.
            foreach (var memory in model.memories)
            {
                bool matchesInput  = memory.input == name;
                bool matchesOutput = memory.output == name;

                if (matchesInput || matchesOutput)
                {
                    return memory.shape;
                }
            }

            // Nothing matched: report the missing name to the caller.
            throw new System.Collections.Generic.KeyNotFoundException("Shape " + name + " not found!");
        }
        /// <summary>
        /// Loads the targeted <c>NNModel</c> without its weights and fills the cached
        /// name/description lists (inputs, outputs, memories, layers, constants, warnings)
        /// together with the weight counters. Returns early, leaving the cached fields
        /// untouched, when the target is not an <c>NNModel</c>, has no model data,
        /// or fails to load.
        /// </summary>
        void OnEnable()
        {
            // TODO: investigate perf -- method takes 1s the first time you click on the model in the UI
            var nnModel = target as NNModel;

            // `== null` is kept on purpose (rather than `is null` / `?.`) so that any
            // equality operator overloaded by the target type keeps being honored.
            if (nnModel == null || nnModel.modelData == null)
            {
                return;
            }

            m_Model = ModelLoader.Load(nnModel, verbose: false, skipWeights: true);
            if (m_Model == null)
            {
                return;
            }

            m_Inputs     = m_Model.inputs.Select(i => i.name).ToList();
            m_InputsDesc = m_Model.inputs.Select(i => $"shape: ({String.Join(",", i.shape)})").ToList();
            m_Outputs    = m_Model.outputs.ToList();

            // Output shapes are only evaluated when every input shape is fully known,
            // i.e. no dimension is -1 or 0. Collection stops at the first unknown input.
            bool allKnownShapes = true;
            var  inputShapes    = new Dictionary<string, TensorShape>();

            foreach (var input in m_Model.inputs)
            {
                if (input.shape.Contains(-1) || input.shape.Contains(0))
                {
                    allKnownShapes = false;
                    break;
                }
                inputShapes.Add(input.name, new TensorShape(input.shape));
            }

            if (allKnownShapes)
            {
                m_OutputsDesc = m_Model.outputs.Select(outputName =>
                {
                    // Placeholder used when the shape cannot be resolved or evaluation throws.
                    string output = "(-1,-1,-1,-1)";
                    try
                    {
                        TensorShape shape;
                        if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, outputName, out shape))
                        {
                            output = shape.ToString();
                        }
                    }
                    catch (Exception e)
                    {
                        // Best effort: log and fall back to the placeholder instead of aborting.
                        Debug.LogError($"Unexpected error while evaluating model output {outputName}. {e}");
                    }
                    return $"shape: {output}";
                }).ToList();
            }
            else
            {
                m_OutputsDesc = m_Model.outputs.Select(i => "shape: (-1,-1,-1,-1)").ToList();
            }

            m_Memories     = m_Model.memories.Select(i => i.input).ToList();
            m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

            // Split layers into Load layers ("constants") and everything else.
            // Materialized once: each list is consumed three times below, and a deferred
            // Where() query would re-filter m_Model.layers on every enumeration.
            var layers    = m_Model.layers.Where(i => i.type != Layer.Type.Load).ToList();
            var constants = m_Model.layers.Where(i => i.type == Layer.Type.Load).ToList();

            m_Layers        = layers.Select(i => i.type.ToString()).ToList();
            m_LayersDesc    = layers.Select(i => i.ToString()).ToList();
            m_Constants     = constants.Select(i => i.type.ToString()).ToList();
            m_ConstantsDesc = constants.Select(i => i.ToString()).ToList();

            m_NumEmbeddedWeights = layers.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));
            m_NumConstantWeights = constants.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));

            // weights are not loaded for UI, recompute size
            // NOTE(review): this accumulates dataset `length` values as-is; confirm whether
            // a per-element byte size should be applied to match the field name.
            m_TotalWeightsSizeInBytes = 0;
            for (var l = 0; l < m_Model.layers.Count; ++l)
            {
                for (var d = 0; d < m_Model.layers[l].datasets.Length; ++d)
                {
                    m_TotalWeightsSizeInBytes += m_Model.layers[l].datasets[d].length;
                }
            }

            m_Warnings     = m_Model.Warnings.Select(i => i.LayerName).ToList();
            m_WarningsDesc = m_Model.Warnings.Select(i => i.Message).ToList();
        }