/// <summary>
/// Builds and runs the distributed-schedule test model for the single process identified by
/// <c>comm.Rank</c>, and stores that process's inferred marginals into the shared result arrays.
/// </summary>
/// <param name="comm">Communicator for this process. <c>comm.Rank</c> selects the result slot; <c>comm.Size</c> sets the process count.</param>
/// <param name="xResults">Output: <c>xResults[comm.Rank]</c> receives the marginals of <c>x</c>.</param>
/// <param name="shiftResults">Output: <c>shiftResults[comm.Rank]</c> receives the marginal of <c>shift</c>.</param>
/// <param name="distributedCommunicationInfo">
/// This process's share of the graph. It supplies the node-to-item indices, the local array length,
/// and the array indices to send and receive per distributed stage and peer process.
/// </param>
/// <param name="scheduleForProcess">
/// Indexed [distributedStage][i]. Its length sets the number of distributed stages. When
/// <paramref name="schedulePerThreadForProcess"/> is null, it is observed as "nodesInLocalBlock".
/// </param>
/// <param name="schedulePerThreadForProcess">
/// Optional, indexed [distributedStage][thread][block][i]. When non-null, it is observed as "GamesInBlock"
/// and drives the <c>DistributedSchedule</c> attribute instead of <paramref name="scheduleForProcess"/>.
/// </param>
private static void DistributedScheduleTestProcess(
    ICommunicator comm,
    IList<Gaussian>[] xResults,
    Gaussian[] shiftResults,
    Compiler.Graphs.DistributedCommunicationInfo distributedCommunicationInfo,
    int[][] scheduleForProcess,
    int[][][][] schedulePerThreadForProcess)
{
    // Core model: a Gaussian array x, touched by each node through a subarray of x,
    // plus a point-estimated shift that every node constrains.
    var nodeCountVar = Variable.Observed(0).Named("nodeCount");
    Range node = new Range(nodeCountVar).Named("node");
    node.AddAttribute(new Sequential() { BackwardPass = true });
    var itemCountVar = Variable.Observed(0).Named("itemCount");
    Range item = new Range(itemCountVar).Named("item");
    var x = Variable.Array<double>(item).Named("x");
    x[item] = Variable.GaussianFromMeanAndPrecision(0, 1).ForEach(item);
    var parentCount = Variable.Observed(default(int[]), node).Named("parentCount");
    Range parent = new Range(parentCount[node]).Named("parent");
    var indices = Variable.Observed(default(int[][]), node, parent).Named("indices");
    indices.SetValueRange(item);
    var shift = Variable.GaussianFromMeanAndPrecision(0, 1).Named("shift");
    shift.AddAttribute(new PointEstimate());
    using (Variable.ForEach(node))
    {
        var subArray = Variable.Subarray(x, indices[node]).Named("subArray");
        using (Variable.If(parentCount[node] == 1))
        {
            Variable.ConstrainEqualRandom(subArray[0], Gaussian.FromMeanAndVariance(0, 1));
        }
        using (Variable.If(parentCount[node] == 2))
        {
            Variable.ConstrainEqual(subArray[0], subArray[1] + 1);
        }
        Variable.ConstrainEqualRandom(shift, new Gaussian(1, 2));
    }

    // this dummy part of the model causes depth cloning to occur
    Range item2 = new Range(0);
    var indices2 = Variable.Observed(new int[0], item2).Named("indices2");
    var subArray2 = Variable.Subarray(x, indices2);
    Variable.ConstrainEqual(subArray2[item2], 0.0);

    // Distributed-schedule structure.
    var distributedStageCount = Variable.Observed(0).Named("distributedStageCount");
    Range distributedStage = new Range(distributedStageCount).Named("stage");
    var commVar = Variable.Observed(default(ICommunicator)).Named("comm");
    if (schedulePerThreadForProcess != null)
    {
        var threadCount = Variable.Observed(0).Named("threadCount");
        Range thread = new Range(threadCount).Named("thread");
        var blockCountOfDistributedStage = Variable.Observed(default(int[]), distributedStage).Named("blockCount");
        Range gameBlock = new Range(blockCountOfDistributedStage[distributedStage]).Named("block");
        var gameCountInBlockOfDistributedStage = Variable.Observed(default(int[][][]), distributedStage, thread, gameBlock).Named("GameCountInBlock");
        Range gameInBlock = new Range(gameCountInBlockOfDistributedStage[distributedStage][thread][gameBlock]).Named("gameInBlock");
        var gamesInBlockOfDistributedStage = Variable.Array(Variable.Array(Variable.Array(Variable.Array<int>(gameInBlock), gameBlock), thread), distributedStage).Named("GamesInBlock");
        // NOTE(review): presumably this marks the array as observed before it is handed to
        // DistributedSchedule; the real value is assigned further down.
        gamesInBlockOfDistributedStage.ObservedValue = default(int[][][][]);
        node.AddAttribute(new DistributedSchedule(commVar, gamesInBlockOfDistributedStage));

        // Thread and block counts are read from the first stage / first thread. Guard the empty cases
        // (same as parallelScheduleTest does) so an empty schedule yields zero-sized ranges
        // instead of an IndexOutOfRangeException.
        threadCount.ObservedValue = (schedulePerThreadForProcess.Length == 0) ? 0 : schedulePerThreadForProcess[0].Length;
        blockCountOfDistributedStage.ObservedValue = Util.ArrayInit(
            schedulePerThreadForProcess.Length,
            stageIndex => (schedulePerThreadForProcess[stageIndex].Length == 0) ? 0 : schedulePerThreadForProcess[stageIndex][0].Length);
        gameCountInBlockOfDistributedStage.ObservedValue = Util.ArrayInit(
            schedulePerThreadForProcess.Length,
            stageIndex => Util.ArrayInit(
                schedulePerThreadForProcess[stageIndex].Length,
                t => Util.ArrayInit(
                    schedulePerThreadForProcess[stageIndex][t].Length,
                    b => schedulePerThreadForProcess[stageIndex][t][b].Length)));
        gamesInBlockOfDistributedStage.ObservedValue = schedulePerThreadForProcess;
    }
    else
    {
        var gameCountInLocalBlock = Variable.Observed(new int[0], distributedStage).Named("gameCountInLocalBlock");
        Range gameInLocalBlock = new Range(gameCountInLocalBlock[distributedStage]).Named("gameInLocalBlock");
        var nodesInLocalBlock = Variable.Observed(new int[0][], distributedStage, gameInLocalBlock).Named("nodesInLocalBlock");
        node.AddAttribute(new DistributedSchedule(commVar, nodesInLocalBlock));
        gameCountInLocalBlock.ObservedValue = Util.ArrayInit(scheduleForProcess.Length, stageIndex => scheduleForProcess[stageIndex].Length);
        nodesInLocalBlock.ObservedValue = scheduleForProcess;
    }

    // Cross-process communication: for each stage and peer process, the array indices to send / receive.
    var processCount = Variable.Observed(0).Named("processCount");
    Range sender = new Range(processCount);
    var arrayIndicesToSendCount = Variable.Observed(default(int[][]), distributedStage, sender).Named("arrayIndicesToSendCount");
    Range arrayIndexToSend = new Range(arrayIndicesToSendCount[distributedStage][sender]);
    var arrayIndicesToSendVar = Variable.Observed(default(int[][][]), distributedStage, sender, arrayIndexToSend).Named("arrayIndicesToSend");
    var arrayIndicesToReceiveCount = Variable.Observed(default(int[][]), distributedStage, sender).Named("arrayIndicesToReceiveCount");
    Range arrayIndexToReceive = new Range(arrayIndicesToReceiveCount[distributedStage][sender]);
    var arrayIndicesToReceiveVar = Variable.Observed(default(int[][][]), distributedStage, sender, arrayIndexToReceive).Named("arrayIndexToReceive");
    indices.AddAttribute(new DistributedCommunication(arrayIndicesToSendVar, arrayIndicesToReceiveVar));

    // Bind observed values for this process.
    distributedStageCount.ObservedValue = scheduleForProcess.Length;
    commVar.ObservedValue = comm;
    processCount.ObservedValue = comm.Size;
    nodeCountVar.ObservedValue = distributedCommunicationInfo.indices.Length;
    itemCountVar.ObservedValue = distributedCommunicationInfo.arrayLength;
    parentCount.ObservedValue = distributedCommunicationInfo.indicesCount;
    indices.ObservedValue = distributedCommunicationInfo.indices;
    arrayIndicesToSendCount.ObservedValue = distributedCommunicationInfo.arrayIndicesToSendCount;
    arrayIndicesToSendVar.ObservedValue = distributedCommunicationInfo.arrayIndicesToSend;
    arrayIndicesToReceiveCount.ObservedValue = distributedCommunicationInfo.arrayIndicesToReceiveCount;
    arrayIndicesToReceiveVar.ObservedValue = distributedCommunicationInfo.arrayIndicesToReceive;

    // A per-rank model name gives each process its own generated inference code.
    InferenceEngine engine = new InferenceEngine();
    engine.ModelName = "DistributedScheduleTest" + comm.Rank;
    engine.ShowProgress = false;
    engine.NumberOfIterations = 2;
    engine.OptimiseForVariables = new IVariable[] { x, shift };
    var xActual = engine.Infer<IList<Gaussian>>(x);
    xResults[comm.Rank] = xActual;
    var shiftActual = engine.Infer<Gaussian>(shift);
    shiftResults[comm.Rank] = shiftActual;
}
/// <summary>
/// Runs the same small model twice: first with the default schedule, then with a
/// <c>ParallelSchedule</c> built by <c>ParallelScheduler</c>. Asserts that both runs produce
/// equal marginals for <c>x</c> and <c>shift</c>.
/// </summary>
/// <param name="numThreads">Thread count passed to the scheduler when building the barrier schedule.</param>
/// <param name="nodeCount">Number of nodes; <c>x</c> also has one element per node.</param>
/// <param name="xMarginal">Receives the marginals of <c>x</c> from the parallel-scheduled run.</param>
/// <param name="shiftMarginal">Receives the marginal of <c>shift</c> from the parallel-scheduled run.</param>
/// <param name="variablesUsedByNode">Receives, for each node, the indices into <c>x</c> that the node reads.</param>
/// <returns>The per-thread schedule, indexed [thread][block][i].</returns>
private int[][][] parallelScheduleTest(int numThreads, int nodeCount, out IReadOnlyList<Gaussian> xMarginal, out Gaussian shiftMarginal, out IReadOnlyList<int[]> variablesUsedByNode)
{
    // Decide which x-elements each node touches, then split the nodes across threads.
    const int maxParents = 1;
    IReadOnlyList<int[]> usage = GenerateVariablesUsedByNode(nodeCount, maxParents);
    variablesUsedByNode = usage;

    var scheduler = new ParallelScheduler();
    scheduler.CreateGraph(usage);
    var barrierSchedule = scheduler.GetScheduleWithBarriers(numThreads);
    ParallelScheduler.WriteSchedule(barrierSchedule, true);
    var perThread = scheduler.ConvertToSchedulePerThread(barrierSchedule, numThreads);

    // Model: x[node] ~ N(0,1); each node constrains its subarray; shift is point-estimated.
    Range node = new Range(Variable.Observed(nodeCount).Named("nodeCount")).Named("node");
    node.AddAttribute(new Sequential() { BackwardPass = true });
    var x = Variable.Array<double>(node).Named("x");
    x[node] = Variable.GaussianFromMeanAndPrecision(0, 1).ForEach(node);
    int[] parentCounts = usage.Select(used => used.Length).ToArray();
    var parentCount = Variable.Observed(parentCounts, node).Named("parentCount");
    Range parent = new Range(parentCount[node]).Named("parent");
    var indices = Variable.Observed(usage.ToArray(), node, parent).Named("indices");
    var shift = Variable.GaussianFromMeanAndPrecision(0, 1).Named("shift");
    shift.AddAttribute(new PointEstimate());
    using (Variable.ForEach(node))
    {
        var nodeInputs = Variable.Subarray(x, indices[node]).Named("subArray");
        using (Variable.If(parentCount[node] == 1))
        {
            Variable.ConstrainEqualRandom(nodeInputs[0], Gaussian.FromMeanAndVariance(0, 1));
        }
        using (Variable.If(parentCount[node] == 2))
        {
            Variable.ConstrainEqual(nodeInputs[0], nodeInputs[1] + 1);
        }
        Variable.ConstrainEqualRandom(shift, new Gaussian(1, 2));
    }

    // Reference run with the default schedule.
    var engine = new InferenceEngine
    {
        NumberOfIterations = 2,
        OptimiseForVariables = new IVariable[] { x, shift },
    };
    var sequentialX = engine.Infer(x);
    var sequentialShift = engine.Infer(shift);

    // Attach the parallel schedule. This must happen after the reference run, so the second
    // Infer call compiles the model with the attribute present.
    var threadCount = Variable.Observed(0).Named("threadCount");
    Range thread = new Range(threadCount).Named("thread");
    var blockCount = Variable.Observed(0).Named("blockCount");
    Range gameBlock = new Range(blockCount).Named("block");
    var gameCountInBlock = Variable.Observed(default(int[][]), thread, gameBlock).Named("GameCountInBlock");
    Range gameInBlock = new Range(gameCountInBlock[thread][gameBlock]).Named("gameInBlock");
    var gamesInBlock = Variable.Observed(default(int[][][]), thread, gameBlock, gameInBlock).Named("GamesInBlock");
    node.AddAttribute(new ParallelSchedule(gamesInBlock));

    int threads = perThread.Length;
    int blocks = 0;
    if (threads > 0)
    {
        blocks = perThread[0].Length;
    }
    threadCount.ObservedValue = threads;
    blockCount.ObservedValue = blocks;
    gameCountInBlock.ObservedValue = Util.ArrayInit(threads, t => Util.ArrayInit(perThread[t].Length, b => perThread[t][b].Length));
    gamesInBlock.ObservedValue = perThread;

    // The parallel run must reproduce the reference marginals exactly.
    var parallelX = engine.Infer<IReadOnlyList<Gaussian>>(x);
    Assert.True(sequentialX.Equals(parallelX));
    var parallelShift = engine.Infer<Gaussian>(shift);
    Assert.True(sequentialShift.Equals(parallelShift));

    xMarginal = parallelX;
    shiftMarginal = parallelShift;
    return perThread;
}