Beispiel #1
0
        /// <summary>
        /// Scripted importer entry point. Reads the model file, sends an import analytics event built from a
        /// structure-only load of the model, and produces the imported asset. That asset is an <see cref="NNModel"/>
        /// main object referencing a hidden <see cref="NNModelData"/> sub-asset that stores the raw file bytes.
        /// </summary>
        /// <param name="ctx">Asset import context: supplies the source path and receives the created objects</param>
        public override void OnImportAsset(AssetImportContext ctx)
        {
            byte[] rawModelBytes = File.ReadAllBytes(ctx.assetPath);

            // Structure-only load (weights skipped), used here only to describe the model to analytics.
            // The event is intended to be sent only if analytics are enabled; that gating is presumably
            // inside BarracudaAnalytics.
            var structureOnlyModel = ModelLoader.Load(ctx.assetPath, skipWeights: true);
            BarracudaAnalytics.SendBarracudaImportEvent(null, structureOnlyModel);

            NNModelData dataSubAsset = CreateHiddenModelData(rawModelBytes);

            var mainAsset = ScriptableObject.CreateInstance <NNModel>();
            mainAsset.modelData = dataSubAsset;

            // Register both objects with the import, then promote the NNModel to be the asset's main object.
            ctx.AddObjectToAsset("main obj", mainAsset, LoadIconTexture());
            ctx.AddObjectToAsset("model data", dataSubAsset);
            ctx.SetMainObject(mainAsset);
        }

        /// <summary>
        /// Creates the "Data" sub-asset holding <paramref name="bytes"/>, hidden from the hierarchy view.
        /// </summary>
        static NNModelData CreateHiddenModelData(byte[] bytes)
        {
            var data = ScriptableObject.CreateInstance <NNModelData>();
            data.Value     = bytes;
            data.name      = "Data";
            data.hideFlags = HideFlags.HideInHierarchy;
            return data;
        }
        /// <summary>
        /// Editor initialisation for an <see cref="NNModel"/> inspector. Loads the model structure with weights
        /// skipped, then caches display strings for its inputs, outputs, memories, layers, constants, weight
        /// counts and warnings. Returns early when the target is not an NNModel or has no model data; in that
        /// case no cached fields are modified. If the load returns null, only m_Model is assigned.
        /// </summary>
        void OnEnable()
        {
            // TODO: investigate perf -- method takes 1s the first time you click on the model in the UI
            var nnModel = target as NNModel;

            // Keep `== null` (not `is null`): NNModel and NNModelData are ScriptableObjects, and
            // UnityEngine.Object overloads == so that destroyed objects also compare equal to null.
            if (nnModel == null || nnModel.modelData == null)
            {
                return;
            }

            m_Model = ModelLoader.Load(nnModel, verbose: false, skipWeights: true);
            if (m_Model == null)
            {
                return;
            }

            m_Inputs      = m_Model.inputs.Select(i => i.name).ToList();
            m_InputsDesc  = m_Model.inputs.Select(i => $"shape: ({String.Join(",", i.shape)})").ToList();
            m_Outputs     = m_Model.outputs.ToList();
            m_OutputsDesc = DescribeOutputShapes();

            m_Memories     = m_Model.memories.Select(i => i.input).ToList();
            m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

            // Partition the layers once. These were previously deferred queries that re-scanned
            // m_Model.layers for every one of the six uses below.
            var layers    = m_Model.layers.Where(i => i.type != Layer.Type.Load).ToList();
            var constants = m_Model.layers.Where(i => i.type == Layer.Type.Load).ToList();

            m_Layers        = layers.Select(i => i.type.ToString()).ToList();
            m_LayersDesc    = layers.Select(i => i.ToString()).ToList();
            m_Constants     = constants.Select(i => i.type.ToString()).ToList();
            m_ConstantsDesc = constants.Select(i => i.ToString()).ToList();

            // Summed as long so large models cannot overflow an int accumulator.
            m_NumEmbeddedWeights = layers.Sum(l => l.datasets.Sum(ds => (long)ds.length));
            m_NumConstantWeights = constants.Sum(l => l.datasets.Sum(ds => (long)ds.length));

            // Weights are not loaded for UI, so recompute the size from the per-dataset metadata.
            // NOTE(review): this adds the same `length` values that are summed as weight *counts* above,
            // so the result may be an element count rather than bytes despite the field name. Confirm.
            m_TotalWeightsSizeInBytes = 0;
            foreach (var layer in m_Model.layers)
            {
                foreach (var dataset in layer.datasets)
                {
                    m_TotalWeightsSizeInBytes += dataset.length;
                }
            }

            m_Warnings     = m_Model.Warnings.Select(i => i.LayerName).ToList();
            m_WarningsDesc = m_Model.Warnings.Select(i => i.Message).ToList();
        }

        /// <summary>
        /// Builds one "shape: ..." description per output of m_Model. Output shapes are only inferred when
        /// every input shape is fully known, meaning it contains no -1 and no 0 dimension. Otherwise every
        /// output reports "(-1,-1,-1,-1)". That placeholder is also used for any single output whose inference
        /// fails or throws; a throw is logged.
        /// </summary>
        List<string> DescribeOutputShapes()
        {
            const string unknownShape = "(-1,-1,-1,-1)";

            var inputShapes = new Dictionary <string, TensorShape>();
            foreach (var input in m_Model.inputs)
            {
                if (input.shape.Contains(-1) || input.shape.Contains(0))
                {
                    // Shape inference needs concrete dimensions for every input.
                    return m_Model.outputs.Select(_ => $"shape: {unknownShape}").ToList();
                }
                inputShapes.Add(input.name, new TensorShape(input.shape));
            }

            return m_Model.outputs.Select(outputName => {
                string shapeText = unknownShape;
                try
                {
                    TensorShape shape;
                    if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, outputName, out shape))
                    {
                        shapeText = shape.ToString();
                    }
                }
                catch (Exception e)
                {
                    // Best effort: one failing output must not break the whole inspector, so log and continue.
                    Debug.LogError($"Unexpected error while evaluating model output {outputName}. {e}");
                }
                return($"shape: {shapeText}");
            }).ToList();
        }