Exemplo n.º 1
0
        private static void _patchFor <T>(Tensor <T> .For @for, Loop loop, Patch substitutions, List <IFor> patchFors)
        {
            // the output has been patched already, it should be in substitutions.
            var patchOutput = (Tensor <T>)@for.OutputInfo?.Patch(substitutions);
            // we just put the patched version of the variable in substitutions
            var patchVar = (Tensor <T> .Var)@for.RecursiveVariable?.Patch(substitutions);
            // we need to patch the expression to use the new variables
            var patchExpr = (Tensor <T>)@for.Expression.Patch(substitutions);

            var patchFor = new Tensor <T> .For(loop, @for.Index, patchExpr, patchOutput, patchVar);

            patchFors.Add(patchFor);
            substitutions.Add_(@for, patchFor);
        }
Exemplo n.º 2
0
        private static ITensorVar _patchVar(ITensorVar v, Patch substitutions)
        {
            var lastChar      = v.Name[v.Name.Length - 1];
            var endsWithDigit = lastChar >= '0' && lastChar <= '9';
            var name          = v.Name;

            if (endsWithDigit)
            {
                name = v.Name.Substring(0, v.Name.Length - 1) + (char)((lastChar - '0') + 1);
            }
            else
            {
                name += "_1";
            }
            ITensorVar res;

            if (v is Tensor <float> .Var)
            {
                res = new Tensor <float> .Var(v.Shape.Patch(substitutions), name);
            }
            else if (v is Tensor <int> .Var)
            {
                res = new Tensor <int> .Var(v.Shape.Patch(substitutions), name);
            }
            else if (v is Tensor <double> .Var)
            {
                res = new Tensor <double> .Var(v.Shape.Patch(substitutions), name);
            }
            else
            {
                throw new NotImplementedException();
            }

            substitutions.Add_(v, res);
            return(res);
        }
Exemplo n.º 3
0
        internal Loop Patch(Patch substitutions)
        {
            Loop result;

            if (substitutions.TryGetValue(this, out result))
            {
                return(result);
            }

            // patch the length
            var patchLength = (Scalar <int>)Length.Patch(substitutions);

            // patch the sequences and outputInfo from the loop
            var patchSequences   = Sequences.Patch(substitutions);
            var outputsInfo      = Fors.Select(f => f.OutputInfo).ToArray();
            var patchOutputsInfo = outputsInfo.Patch(substitutions);

            // the expression uses the loop variables
            // we do a first patch to detect if we need to create a new loop
            var blankSubstitutions = new Patch(substitutions);
            var expressions        = Fors.Select(f => f.Expression).ToArray();
            var patchExpressions   = expressions.Patch(blankSubstitutions);

            if (patchLength == Length &&
                patchOutputsInfo == outputsInfo &&
                patchSequences == Sequences &&
                patchExpressions == expressions)
            {
                // the expressions have been correctly patched
                foreach (var(expr, value) in (expressions, patchExpressions).Zip())
                {
                    substitutions.Add_(expr, value);
                }
                result = this;
            }
            else
            {
                // we need to create a new loop in particular we need to create new hidden variables for the loop
                var loopName = LoopName();
                // create the patched Loop
                result = new Loop(loopName, patchLength);

                // created variables are automatically added to `substitutions`
                var createdVariables = Variables.Select(v => _patchVar(v, substitutions)).ToList();

                for (int i = 0; i < Sequences.Count; ++i)
                {
                    var seq      = patchSequences[i];
                    var patchVar = createdVariables[i];
                    patchVar.Match(
                        (Tensor <float> .Var varF) => (ITensorVar)result.AddSeq((Tensor <float>)seq, varF, SequenceAxes[i]),
                        (Tensor <int> .Var varI) => result.AddSeq((Tensor <int>)seq, varI, SequenceAxes[i]),
                        null
                        );
                }

                foreach (var @for in Fors)
                {
                    IFor patchFor;

                    // we need to patch the expression to use the new variables
                    var patchExpr = (ITensor)@for.Expression.Patch(substitutions);
                    if (@for.IsRecursive)
                    {
                        // the output has been patched already, it should be in substitutions.
                        var patchOutput = (ITensor)@for.OutputInfo.Patch(substitutions);
                        // we just put the patched version of the variable in substitutions
                        var patchVar = (ITensorVar)@for.RecursiveVariable.Patch(substitutions);

                        patchFor = result.AddRecursive_(patchExpr, patchOutput, patchVar);
                    }
                    else
                    {
                        patchFor = result.AddOutput_(patchExpr);
                    }

                    substitutions.Add_(@for, patchFor);
                }
            }
            substitutions.Add(this, result);
            return(result);
        }