Exemple #1
0
 /// <summary>
 /// Hands the transformer chain to the transformer chain visualizer, if one is available.
 /// Does nothing when no visualizer (or no transformer chain visualizer) is set.
 /// </summary>
 /// <param name="tc">The transformer chain to visualize.</param>
 /// <param name="folder">Folder argument forwarded unchanged to the visualizer.</param>
 /// <param name="name">Name argument forwarded unchanged to the visualizer.</param>
 internal void ShowBrowser(TransformerChain tc, string folder, string name)
 {
     var chainVisualizer = InferenceEngine.Visualizer?.TransformerChainVisualizer;
     if (chainVisualizer == null)
     {
         // Nothing is available to display the chain.
         return;
     }

     chainVisualizer.VisualizeTransformerChain(tc, folder, name);
 }
        /// <summary>
        /// Get the abstract syntax tree for the generated code by running the transform chain
        /// built for <paramref name="method"/> over <paramref name="itd"/>.
        /// </summary>
        /// <param name="itd">The type declaration handed to the transform chain.</param>
        /// <param name="method">The method used to construct the transform chain.</param>
        /// <param name="inputAttributes">Attributes passed through to the transform chain.</param>
        /// <returns>The type declarations produced by the transform chain.</returns>
        /// <exception cref="CompilationFailedException">
        /// Thrown when the transform chain raises a <c>TransformFailedException</c>; carries that exception's results and message.
        /// </exception>
        internal List<ITypeDeclaration> GetTransformedDeclaration(ITypeDeclaration itd, MethodBase method, AttributeRegistry<object, ICompilerAttribute> inputAttributes)
        {
            TransformerChain chain = ConstructTransformChain(method);

            try
            {
                // Notify listeners that compilation is starting.
                Compiling?.Invoke(this, new CompileEventArgs());

                // Transform tracking is only requested when the browser may be used.
                bool trackTransform = BrowserMode != BrowserMode.Never;
                List<TransformError> warnings;
                List<ITypeDeclaration> declarations = chain.TransformToDeclaration(itd, inputAttributes, trackTransform, ShowProgress, out warnings, CatchExceptions, TreatWarningsAsErrors);

                // Report successful completion together with any warnings collected.
                OnCompiled(new CompileEventArgs { Warnings = warnings });

                if (BrowserMode == BrowserMode.Always)
                {
                    ShowBrowser(chain, GeneratedSourceFolder, declarations[0].Name);
                }
                else if (BrowserMode == BrowserMode.WriteFiles)
                {
                    string transformsFolder = Path.Combine(GeneratedSourceFolder, declarations[0].Name + " Transforms");
                    chain.WriteAllOutputs(transformsFolder);
                }

                if (ShowSchedule && InferenceEngine.Visualizer?.TaskGraphVisualizer != null)
                {
                    foreach (CodeTransformer transformer in chain.transformers)
                    {
                        // Only the DeadCodeTransform stage is visualized.
                        // Alternative kept from the original: ct.Transform as SchedulingTransform
                        DeadCodeTransform deadCode = transformer.Transform as DeadCodeTransform;
                        if (deadCode == null)
                        {
                            continue;
                        }

                        foreach (ITypeDeclaration transformed in transformer.transformMap.Values)
                        {
                            InferenceEngine.Visualizer.TaskGraphVisualizer.VisualizeTaskGraph(transformed, (BasicTransformContext)deadCode.Context);
                        }
                    }
                }

                return declarations;
            }
            catch (TransformFailedException ex)
            {
                // Report the failure to listeners before surfacing it to the caller.
                OnCompiled(new CompileEventArgs { Exception = ex });

                if (BrowserMode != BrowserMode.Never)
                {
                    // No output declarations exist on failure, so the input declaration's name is used.
                    ShowBrowser(chain, GeneratedSourceFolder, itd.Name);
                }

                throw new CompilationFailedException(ex.Results, ex.Message);
            }
        }
Exemple #3
0
        /// <summary>
        /// Construct the transform chain for the given method.
        /// </summary>
        /// <remarks>
        /// A new chain is created on every call. Which transforms are added depends on the selected
        /// algorithm (GibbsSampling takes a different variable/channel path) and on compiler options
        /// such as UnrollLoops, OptimiseInferenceCode, UseSerialSchedules, UseExperimentalSerialSchedules,
        /// UseParallelForLoops and ExtraTransforms. The order in which transforms are added matters;
        /// the inline comments record the known ordering constraints.
        /// </remarks>
        /// <param name="method">The method; passed to the loop unrolling and model isolation transforms.</param>
        /// <returns>A new transformer chain populated with the selected transforms.</returns>
        /// <exception cref="InferCompilerException">Thrown when no algorithm has been specified.</exception>
        private TransformerChain ConstructTransformChain(MethodBase method)
        {
            // An algorithm is required: several transforms below take it as a constructor argument.
            if (algorithm == null)
            {
                throw new InferCompilerException("No algorithm was specified, please specify one before compiling.");
            }
            TransformerChain tc = new TransformerChain();

            //if (args != null) tc.AddTransform(new ParameterInsertionTransform(method,args));
            if (UnrollLoops)
            {
                tc.AddTransform(new LoopUnrollingTransform(method));
            }
            // GibbsSampling skips VariableTransform/Channel2Transform and uses ChannelTransform and GroupTransform instead.
            bool useVariableTransform = !(algorithm is GibbsSampling);
            // Depth cloning currently mirrors useVariableTransform, so the useDepthCloning checks nested
            // inside useVariableTransform branches below are always true while this holds.
            bool useDepthCloning      = useVariableTransform;

            // Model analysis stages, added unconditionally.
            tc.AddTransform(new IsolateModelTransform(method));
            tc.AddTransform(new ExternalVariablesTransform());
            tc.AddTransform(new IntermediateVariableTransform());
            tc.AddTransform(new ModelAnalysisTransform());
            tc.AddTransform(new ArrayAnalysisTransform());
            tc.AddTransform(new EqualityPropagationTransform());
            tc.AddTransform(new StocAnalysisTransform(true));
            tc.AddTransform(new MarginalAnalysisTransform());

            tc.AddTransform(new GateTransform(algorithm));
            tc.AddTransform(new IndexingTransform());
            if (useDepthCloning)
            {
                // DepthCloningTransform needs two passes since the first pass may create new code that needs to be transformed.
                // See SwitchDeepArrayCopyTest.
                //tc.AddTransform(new DepthCloningTransform(false));
                // Unfortunately this breaks ArrayUsedAtManyDepths2.
                tc.AddTransform(new DepthCloningTransform(true));
                tc.AddTransform(new ReplicationTransform());
                // IfCutting must be between Depth and Replication because loops are added
                tc.AddTransform(new IfCuttingTransform());
            }
            tc.AddTransform(new DerivedVariableTransform());
            tc.AddTransform(new PowerTransform());
            tc.AddTransform(new ReplicationTransform());
            if (useVariableTransform)
            {
                tc.AddTransform(new VariableTransform(algorithm));
                if (useDepthCloning)
                {
                    // First Channel2 pass (second argument false); the second pass below passes true.
                    tc.AddTransform(new Channel2Transform(useDepthCloning, false));
                }
            }
            // IfCutting must be after Variable because Variable factors send evidence
            tc.AddTransform(new IfCuttingTransform());
            if (useVariableTransform)
            {
                // must do Replication here since Channel2 could have created new loops (see ArrayUsedAtManyDepths3)
                if (useDepthCloning)
                {
                    tc.AddTransform(new ReplicationTransform());
                }
                tc.AddTransform(new Channel2Transform(useDepthCloning, true));
                tc.AddTransform(new PointMassAnalysisTransform());
            }
            else
            {
                // Path taken when the algorithm is GibbsSampling.
                tc.AddTransform(new ChannelTransform(algorithm));
            }
            if (algorithm is GibbsSampling)
            {
                tc.AddTransform(new GroupTransform(algorithm));
            }
            //   tc.AddTransform(new HybridAlgorithmTransform(engine.Algorithm));
            tc.AddTransform(new MessageTransform(this, algorithm, factorManager, this.AllowDerivedParents));
            tc.AddTransform(new IncrementTransform(this));
            if (OptimiseInferenceCode)
            {
                // LoopCutting must precede CopyPropagation because you can have situations such as:
                // for(N) {
                //   bool local = ...;
                //   array[N] = Copy(local);
                // }
                // SomeFactor(array);
                // The same LoopCuttingTransform instance is added to the chain twice.
                var lct = new LoopCuttingTransform(false);
                tc.AddTransform(lct);
                tc.AddTransform(lct); // run again to catch uses before declaration
                // TODO: fix CopyPropagation so it only needs to run once.
                tc.AddTransform(new CopyPropagationTransform());
                tc.AddTransform(new CopyPropagationTransform());
                tc.AddTransform(new CopyPropagationTransform());
                tc.AddTransform(new HoistingTransform(this));
                // LoopCutting must follow Hoisting since new hoist variables can be loop locals
            }
            // This LoopCutting pass (constructed with true) is added regardless of OptimiseInferenceCode.
            var lct2 = new LoopCuttingTransform(true);

            tc.AddTransform(lct2);
            tc.AddTransform(lct2); // run again to catch uses before declaration
            if (OptimiseInferenceCode)
            {
                // must run after HoistingTransform
                tc.AddTransform(new LoopRemovalTransform());
            }
            tc.AddTransform(new DependencyAnalysisTransform());
            tc.AddTransform(new PruningTransform());
            tc.AddTransform(new IterationTransform(this));
            tc.AddTransform(new IncrementPruningTransform());
            tc.AddTransform(new InitializerTransform(this));
            // ForwardBackward is only added for serial schedules when the experimental variant is not selected.
            if (UseSerialSchedules && !UseExperimentalSerialSchedules)
            {
                tc.AddTransform(new ForwardBackwardTransform(this));
            }
            tc.AddTransform(new SchedulingTransform(this));
            tc.AddTransform(new UniquenessTransform());
            tc.AddTransform(new DependencyPruningTransform());
            if (OptimiseInferenceCode)
            {
                tc.AddTransform(new LoopReversalTransform());
            }
            tc.AddTransform(new LocalAllocationTransform(this));
            tc.AddTransform(new DeadCodeTransform(this, true));
            // add any extra transforms provided by the user
            if (this.ExtraTransforms != null)
            {
                // Each entry is invoked to produce the transform instance that gets added.
                foreach (var transform in this.ExtraTransforms)
                {
                    tc.AddTransform(transform());
                }
            }
            tc.AddTransform(new IterativeProcessTransform(this, algorithm));
            // LoopMerging is required to support offset indexing (see GateModelTests.CaseLoopIndexTest2)
            tc.AddTransform(new LoopMergingTransform());
            tc.AddTransform(new IsIncreasingTransform());
            // Local is required for DistributedTests
            tc.AddTransform(new LocalTransform(this));
            if (OptimiseInferenceCode)
            {
                tc.AddTransform(new DeadCode2Transform(this));
            }
            tc.AddTransform(new ParallelScheduleTransform());
            // All messages after each iteration will be logged to csv files in a folder named with the model name.
            // Use MatlabWriter.WriteFromCsvFolder to convert these to a mat file.
            // Hard-coded off: while this is false the TracingTransform is never added, whatever TraceAllMessages is.
            bool useTracingTransform = false;

            if (TraceAllMessages && useTracingTransform)
            {
                tc.AddTransform(new TracingTransform());
            }
            // Hard-coded off: change to true by hand to add the array size tracing transform.
            bool useArraySizeTracing = false;

            if (useArraySizeTracing)
            {
                // This helps isolate memory performance issues.
                tc.AddTransform(new ArraySizeTracingTransform());
            }
            if (UseParallelForLoops)
            {
                tc.AddTransform(new ParallelForTransform());
            }
            // Logging is the last transform in the chain.
            tc.AddTransform(new LoggingTransform(this));
            return(tc);
        }