コード例 #1
0
ファイル: MLEvaluator.cs プロジェクト: erisonliang/anndotnet
        /// <summary>
        /// Evaluates the CNTK model stored at <paramref name="modelPath"/> against a set of image files
        /// and collects one integer result per evaluated sample.
        /// </summary>
        /// <param name="modelPath">Path of the model file; it is loaded with <c>Function.Load</c>.</param>
        /// <param name="imagePaths">
        /// Paths of the images to evaluate. Each path is written to a map file as
        /// "<c>path&lt;TAB&gt;0</c>" (the label column is a constant 0).
        /// </param>
        /// <param name="device">Device on which the model is loaded and evaluated.</param>
        /// <returns>
        /// One value per evaluated sample: the result of <c>MLValue.GetResult</c> applied to the
        /// model's first output, cast to <see cref="int"/>. An empty list is returned when
        /// <paramref name="imagePaths"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="imagePaths"/> is <c>null</c>.</exception>
        /// <exception cref="FileNotFoundException">No file exists at <paramref name="modelPath"/>.</exception>
        public static List<int> TestModel(string modelPath, string[] imagePaths, DeviceDescriptor device)
        {
            // Upper bound for the number of samples requested per minibatch.
            const int maxMinibatchSize = 30;

            if (imagePaths == null)
            {
                throw new ArgumentNullException(nameof(imagePaths));
            }

            var fi = new FileInfo(modelPath);
            if (!fi.Exists)
            {
                throw new FileNotFoundException(
                    $"The '{fi.FullName}' does not exist. Make sure the model is placed at this location.",
                    fi.FullName);
            }

            // Nothing to evaluate: avoid requesting zero-sized minibatches in the loop below.
            if (imagePaths.Length == 0)
            {
                return new List<int>();
            }

            // Load the model from disk and take its arguments/outputs as the feature/label variables.
            var model        = Function.Load(fi.FullName, device);
            var features     = model.Arguments.ToList();
            var labels       = model.Outputs.ToList();
            var stremsConfig = MLFactory.CreateStreamConfiguration(features, labels);

            // Only the first output is evaluated; resolve it once instead of on every iteration.
            var outputVar = labels.First();

            // The minibatch source reads the images through a map file: one "path<TAB>label" line per image.
            // NOTE(review): the map file is created in the current working directory under a fixed name and
            // is never removed; confirm whether MinibatchSourceEx keeps it open before adding cleanup.
            var mapFile = "testMapFile";
            File.WriteAllLines(mapFile, imagePaths.Select(x => $"{x}\t0"));

            // NOTE(review): the meaning of the trailing arguments (null, 30, false, 0) is defined by
            // MinibatchSourceEx, which is not visible here, so they are passed through unchanged.
            var testMB = new MinibatchSourceEx(MinibatchType.Image, stremsConfig.ToArray(), features, labels, mapFile, null, 30, false, 0);

            var vars   = features.Union(labels).ToList();
            var retVal = new List<int>();
            var mbSize = Math.Min(imagePaths.Length, maxMinibatchSize);

            bool isSweepEnd;
            do
            {
                isSweepEnd = false;
                var inputMap = testMB.GetNextMinibatch((uint)mbSize, ref isSweepEnd, vars, device);

                // A null value asks Evaluate to allocate the output for us.
                var outputMap = new Dictionary<Variable, Value> { { outputVar, null } };
                model.Evaluate(inputMap, outputMap, device);

                // Convert every output row of this minibatch into a single integer result.
                foreach (var row in outputMap[outputVar].GetDenseData<float>(outputVar))
                {
                    retVal.Add((int)MLValue.GetResult(row));
                }
            }
            while (!isSweepEnd); // stop once the source reports the end of the data sweep

            return retVal;
        }