Exemplo n.º 1
0
    /// <summary>
    /// Saves the given models, and optionally the models they reference, into a ModelBlobs instance.
    /// </summary>
    /// <remarks>
    /// The caller's list is never modified; a de-duplicated working copy is serialized instead.
    /// The result contains a "manifest" entry produced by Manifest.Save, followed by one xml blob per model keyed by model.Id.
    /// </remarks>
    /// <param name="models">The models that will be saved.</param>
    /// <param name="saveReferenced">Whether the models returned by each given model's GetReferences() should be saved as well.</param>
    /// <returns>The manifest plus every saved model as an xml formatted string keyed by model id, or null if <paramref name="models"/> is null.</returns>
    public static ModelBlobs Save(List <Model> models, bool saveReferenced)
    {
        if (models == null)
        {
            UnityEngine.Debug.LogError("Can't save Models if the given list is null");
            return(null);
        }

        // Distinct() returns a new sequence, so its result must be captured to have any effect.
        List <Model> modelsToSave = new List <Model>(models.Distinct <Model>());
        if (saveReferenced)
        {
            // Iterate over a snapshot: appending to a List<T> while enumerating it throws InvalidOperationException.
            foreach (Model model in modelsToSave.ToArray())
            {
                modelsToSave.AddList <Model>(model.GetReferences());
            }
            // Referenced models may duplicate given models or each other.
            modelsToSave = new List <Model>(modelsToSave.Distinct <Model>());
        }

        ModelBlobs modelBlobs = new ModelBlobs();
        modelBlobs.Add("manifest", Manifest.Save(modelsToSave));

        // Always clear the serialization flag, even if a model fails to serialize.
        isSerializing = true;
        try
        {
            foreach (Model model in modelsToSave)
            {
                modelBlobs.Add(model.Id, model.XmlSerializeToString());
            }
        }
        finally
        {
            isSerializing = false;
        }
        return(modelBlobs);
    }
Exemplo n.º 2
0
 /// <summary>
 /// Constructs models from the given ModelBlobs and hands the loaded root Model of type T to the onDone callback.
 /// </summary>
 /// <typeparam name="T">The type of the root instance passed to onDone.</typeparam>
 /// <param name="data">The modelBlobs to read.</param>
 /// <param name="onStart">Invoked first, before the input is validated.</param>
 /// <param name="onProgress">Forwarded to Manifest.LoadAndConstruct to report progress on a scale of 0.0f to 1.0f.</param>
 /// <param name="onDone">Invoked with the root model after every model in 'instances' has called CollectReferences().</param>
 /// <param name="onError">Invoked with a reason when loading fails (including when <paramref name="data"/> is null).</param>
 public static void Load <T>(ModelBlobs data, Action onStart, Action <float> onProgress, Action <T> onDone, Action <string> onError) where T : Model
 {
     if (onStart != null)
     {
         onStart();
     }

     if (data != null)
     {
         // Runs once construction finishes. Let every model in 'instances' resolve its references
         // (presumably the freshly constructed models — confirm) before the caller sees the root.
         Action <T> finishLoading = root =>
         {
             foreach (Model loaded in instances)
             {
                 loaded.CollectReferences();
             }
             if (onDone != null)
             {
                 onDone(root);
             }
         };
         Manifest.LoadAndConstruct <T>(data, onProgress, finishLoading, onError);
         return;
     }

     if (onError != null)
     {
         onError("Failed to load modelBlobs because it is null.");
     }
 }
Exemplo n.º 3
0
    /// <summary>
    /// Reads orbitTestSave.xml, parses it into ModelBlobs and starts loading it via Model.Load with this object's load callbacks.
    /// </summary>
    public void LoadOrbitTest()
    {
        // NOTE(review): per Unity docs Application.dataPath has no trailing separator, so this resolves to
        // "<...>/AssetsorbitTestSave.xml". SaveOrbitTest builds the same string; change both together if fixing.
        string savePath = Application.dataPath + "orbitTestSave.xml";
        string serialized = File.ReadAllText(savePath);

        Model.Load(ModelBlobs.FromString(serialized), OnLoadStart, OnLoadProgress, OnLoadDone, OnLoadError);
    }
Exemplo n.º 4
0
        /// <summary>
        /// Deserializes the manifest stored under the "manifest" key and starts a coroutine (ConstructAsync) that builds the listed models.
        /// </summary>
        /// <typeparam name="T">The type of the root model passed to onDone.</typeparam>
        /// <param name="modelBlobs">Blobs keyed by id; must contain a "manifest" entry.</param>
        /// <param name="onProgress">Forwarded to ConstructAsync.</param>
        /// <param name="onDone">Forwarded to ConstructAsync; receives the root model.</param>
        /// <param name="onError">
        /// Invoked with a reason if <paramref name="modelBlobs"/> is null, has no manifest, or the manifest cannot be deserialized;
        /// also forwarded to ConstructAsync.
        /// </param>
        public static void LoadAndConstruct <T>(ModelBlobs modelBlobs, Action <float> onProgress, Action <T> onDone, Action <string> onError) where T : Model
        {
            // Report a null input through onError instead of failing with a NullReferenceException on ContainsKey.
            if (modelBlobs == null)
            {
                if (onError != null)
                {
                    onError("Failed to load modelBlobs because it is null.");
                }
                return;
            }

            if (!modelBlobs.ContainsKey("manifest"))
            {
                if (onError != null)
                {
                    onError("Failed to load modelBlobs as it does not contain a manifest.");
                }
                return;
            }

            // The deserialized manifest is kept in the static 'instance' field.
            instance = modelBlobs["manifest"].XmlDeserializeFromString <Manifest>();
            if (instance == null)
            {
                if (onError != null)
                {
                    onError("Failed to deserialize manifest.");
                }
                return;
            }

            // Turns a manifest entry into a Model by deserializing its blob as the entry's declared type.
            // NOTE(review): the indexer presumably throws if no blob exists for entry.Id — confirm how ConstructAsync handles that.
            LoadModelDelegate loadModelCallback = delegate(ModelEntry entry) {
                return(modelBlobs[entry.Id].XmlDeserializeFromString(TypeHelper.GetGlobalType(entry.Type)) as Model);
            };

            Coroutiner.Start(ConstructAsync <T>(loadModelCallback, onProgress, onDone, onError));
        }
Exemplo n.º 5
0
 /// <summary>
 /// Non-generic overload: loads the given ModelBlobs through Load&lt;Model&gt; and signals completion without passing the root model.
 /// </summary>
 /// <param name="data">The modelBlobs to read.</param>
 /// <param name="onStart">Invoked when loading starts.</param>
 /// <param name="onProgress">Receives progress on a scale of 0.0f to 1.0f.</param>
 /// <param name="onDone">Invoked when loading finishes; the loaded root model is discarded.</param>
 /// <param name="onError">Invoked with a reason when loading fails.</param>
 public static void Load(ModelBlobs data, Action onStart, Action <float> onProgress, Action onDone, Action <string> onError)
 {
     // Adapt the parameterless onDone to the Action<Model> that Load<Model> expects.
     Action <Model> notifyDone = ignoredRoot =>
     {
         if (onDone != null)
         {
             onDone();
         }
     };
     Load <Model>(data, onStart, onProgress, notifyDone, onError);
 }
Exemplo n.º 6
0
 /// <summary>
 /// Saves all models: to Config.SAVE_DIRECTORY when directory saving is enabled, otherwise into the in-memory savedBlobs field.
 /// </summary>
 private void SaveGame()
 {
     if (!Config.SAVE_IN_DIRECTORY_ENABLED)
     {
         // In-memory save: keep the blobs so they can be loaded later.
         savedBlobs = Model.SaveAll();
         return;
     }

     Model.SaveAll(Config.SAVE_DIRECTORY);
 }
Exemplo n.º 7
0
    /// <summary>
    /// Serializes every model via Model.SaveAll() and writes the result to orbitTestSave.xml.
    /// File.WriteAllText overwrites the file if it already exists.
    /// </summary>
    public void SaveOrbitTest()
    {
        // NOTE(review): per Unity docs Application.dataPath has no trailing separator, so this resolves to
        // "<...>/AssetsorbitTestSave.xml". LoadOrbitTest builds the same string; change both together if fixing.
        string path = Application.dataPath + "orbitTestSave.xml";

        ModelBlobs save = Model.SaveAll();

        File.WriteAllText(path, save.ToString());
    }
Exemplo n.º 8
0
    /// <summary>
    /// Builds ModelBlobs from a flat token array of alternating id / blob values.
    /// A null or empty token in id position is skipped over (the next token is treated as the id instead);
    /// a trailing id with no following blob is ignored.
    /// </summary>
    /// <param name="splittedData">The tokens to pair up; must not be null.</param>
    /// <returns>A ModelBlobs containing every id/blob pair found.</returns>
    private static ModelBlobs FromStringArray(string[] splittedData)
    {
        ModelBlobs result    = new ModelBlobs();
        string     pendingId = null;

        for (int i = 0; i < splittedData.Length; i++)
        {
            string token = splittedData[i];
            if (!string.IsNullOrEmpty(pendingId))
            {
                result.Add(pendingId, token);
                pendingId = null;
                continue;
            }
            pendingId = token;
        }
        return(result);
    }
Exemplo n.º 9
0
        /// <summary>
        /// End-to-end test of the online trainer.
        /// It sends generated events through Event Hub to a LearnEventProcessorHost, waits for the trainer to checkpoint,
        /// replays the same events offline in the order listed by the downloaded trackback, and asserts that the offline
        /// model file ("offline.json.model") is byte-identical to the downloaded online model ("online.model").
        /// Requires the storage and Event Hub connection strings from configuration.
        /// </summary>
        public async Task TestAzureTrainer()
        {
            var storageConnectionString       = GetConfiguration("storageConnectionString");
            var inputEventHubConnectionString = GetConfiguration("inputEventHubConnectionString");
            var evalEventHubConnectionString  = GetConfiguration("evalEventHubConnectionString");

            // Shared by the online trainer (Metadata.TrainArguments) and the offline replay below,
            // so both models are trained with the same arguments.
            var trainArguments = "--cb_explore_adf --epsilon 0.2 -q ab";

            // register with AppInsights to collect exceptions; passed to PollTrainerCheckpoint below
            var exceptions = RegisterAppInsightExceptionHook();

            // cleanup blobs left over from previous runs
            var blobs = new ModelBlobs(storageConnectionString);
            await blobs.Cleanup();

            // 100 generated events keyed by EventId so trackback ids can be mapped back to their JSON later.
            var data = GenerateData(100).ToDictionary(d => d.EventId, d => d);

            // start listening for event hub
            using (var trainProcesserHost = new LearnEventProcessorHost())
            {
                await trainProcesserHost.StartAsync(new OnlineTrainerSettingsInternal
                {
                    // presumably checkpoints after data.Count events have been processed — TODO confirm CountingCheckpointPolicy semantics
                    CheckpointPolicy = new CountingCheckpointPolicy(data.Count),
                    JoinedEventHubConnectionString = inputEventHubConnectionString,
                    EvalEventHubConnectionString   = evalEventHubConnectionString,
                    StorageConnectionString        = storageConnectionString,
                    Metadata = new OnlineTrainerSettings
                    {
                        ApplicationID  = "vwunittest",
                        TrainArguments = trainArguments
                    },
                    EnableExampleTracing     = true,
                    EventHubStartDateTimeUtc = DateTime.UtcNow // ignore any events that arrived before this time
                });

                // send events to event hub
                // NOTE(review): Send is synchronous inside an async test, and eventHubInputClient is never closed.
                var eventHubInputClient = EventHubClient.CreateFromConnectionString(inputEventHubConnectionString);
                data.Values.ForEach(c => eventHubInputClient.Send(new EventData(c.JSONAsBytes)
                {
                    PartitionKey = c.Index.ToString()
                }));

                // wait for trainer to checkpoint
                await blobs.PollTrainerCheckpoint(exceptions);

                // download & parse trackback file; every sent event must appear in it
                var trackback = blobs.DownloadTrackback();
                Assert.AreEqual(data.Count, trackback.EventIds.Count);

                // train model offline using trackback: replay the events in trackback order,
                // tagged with the online model id; -f writes the model to offline.json.model
                var settings = new VowpalWabbitSettings(trainArguments + $" --id {trackback.ModelId} --save_resume --readable_model offline.json.model.txt -f offline.json.model");
                using (var vw = new VowpalWabbitJson(settings))
                {
                    foreach (var id in trackback.EventIds)
                    {
                        var json = data[id].JSON;

                        // NOTE(review): progressivePrediction is currently unused.
                        var progressivePrediction = vw.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);
                        // TODO: validate eval output
                    }

                    // NOTE(review): the comparison below reads offline.json.model (the -f output), not this file.
                    vw.Native.SaveModel("offline.json.2.model");
                }

                // download online model
                new CloudBlob(blobs.ModelBlob.Uri, blobs.BlobClient.Credentials).DownloadToFile("online.model", FileMode.Create);

                // validate that the model is the same
                CollectionAssert.AreEqual(
                    File.ReadAllBytes("offline.json.model"),
                    File.ReadAllBytes("online.model"),
                    "Offline and online model differs. Run to 'vw -i online.model --readable_model online.model.txt' to compare");
            }
        }