static void RunRepeat()
{
    // Quick self-test of the environment before spending time on training.
    RepeatObservationEnvironment.SanityCheck();

    var random = new Random();
    const int RepeatAgents = 3;

    // Uniform random action in [-1, 1) for each agent, shaped [RepeatAgents, 1].
    // Used by SAC during the initial exploration phase (first `startSteps` steps).
    ndarray RepeatRandomActionSampler()
    {
        var actions = new List<float>();
        for (int agent = 0; agent < RepeatAgents; agent++)
            actions.Add((float)random.NextDouble() * 2 - 1);
        return ndarray.FromList(actions)
                      .reshape(new int[] { RepeatAgents, 1 })
                      .AsArray<float>();
    }

    SoftActorCritic.SoftActorCritic.Run(
        new RepeatObservationEnvironment(RepeatAgents),
        agentGroup: null,
        actorCriticFactory: ActorCriticFactory,
        observationDimensions: 1,
        actionDimensions: 1,
        actionLimit: 1,
        feedFrames: 1,
        hiddenSizes: new int[] { 32 },
        maxEpisodeLength: 256,
        replaySize: 1024 * 1024 / 8,
        learningRate: 2e-4f,
        startSteps: 100,
        actionSampler: RepeatRandomActionSampler);
}
// Sanity check: an agent that simply echoes its observation back as its
// action should earn (near-)maximal reward on every step. Asserts via
// Trace.Assert; throws/breaks only if the environment is wired incorrectly.
public static void SanityCheck()
{
    var env = new RepeatObservationEnvironment(agents: 3);
    env.Reset();
    env.Step();

    for (int episode = 0; episode < 100; episode++)
    {
        // Read the latest observation and feed it straight back as the action.
        var observation = (ndarray)env.GetStepResult(null).Item1.obs[0];
        env.SetActions(null, observation);
        env.Step();

        // Every agent's reward must be at or above the 1.99 threshold
        // (elementwise comparison over the reward array, then reduce with all()).
        var stepResult = env.GetStepResult(null);
        var rewardOk = stepResult.Item1.reward >= 1.99f;
        Trace.Assert(rewardOk.all());
    }
}