private static void _patchFor <T>(Tensor <T> .For @for, Loop loop, Patch substitutions, List <IFor> patchFors) { // the output has been patched already, it should be in substitutions. var patchOutput = (Tensor <T>)@for.OutputInfo?.Patch(substitutions); // we just put the patched version of the variable in substitutions var patchVar = (Tensor <T> .Var)@for.RecursiveVariable?.Patch(substitutions); // we need to patch the expression to use the new variables var patchExpr = (Tensor <T>)@for.Expression.Patch(substitutions); var patchFor = new Tensor <T> .For(loop, @for.Index, patchExpr, patchOutput, patchVar); patchFors.Add(patchFor); substitutions.Add_(@for, patchFor); }
private static ITensorVar _patchVar(ITensorVar v, Patch substitutions) { var lastChar = v.Name[v.Name.Length - 1]; var endsWithDigit = lastChar >= '0' && lastChar <= '9'; var name = v.Name; if (endsWithDigit) { name = v.Name.Substring(0, v.Name.Length - 1) + (char)((lastChar - '0') + 1); } else { name += "_1"; } ITensorVar res; if (v is Tensor <float> .Var) { res = new Tensor <float> .Var(v.Shape.Patch(substitutions), name); } else if (v is Tensor <int> .Var) { res = new Tensor <int> .Var(v.Shape.Patch(substitutions), name); } else if (v is Tensor <double> .Var) { res = new Tensor <double> .Var(v.Shape.Patch(substitutions), name); } else { throw new NotImplementedException(); } substitutions.Add_(v, res); return(res); }
internal Loop Patch(Patch substitutions) { Loop result; if (substitutions.TryGetValue(this, out result)) { return(result); } // patch the length var patchLength = (Scalar <int>)Length.Patch(substitutions); // patch the sequences and outputInfo from the loop var patchSequences = Sequences.Patch(substitutions); var outputsInfo = Fors.Select(f => f.OutputInfo).ToArray(); var patchOutputsInfo = outputsInfo.Patch(substitutions); // the expression uses the loop variables // we do a first patch to detect if we need to create a new loop var blankSubstitutions = new Patch(substitutions); var expressions = Fors.Select(f => f.Expression).ToArray(); var patchExpressions = expressions.Patch(blankSubstitutions); if (patchLength == Length && patchOutputsInfo == outputsInfo && patchSequences == Sequences && patchExpressions == expressions) { // the expressions have been correctly patched foreach (var(expr, value) in (expressions, patchExpressions).Zip()) { substitutions.Add_(expr, value); } result = this; } else { // we need to create a new loop in particular we need to create new hidden variables for the loop var loopName = LoopName(); // create the patched Loop result = new Loop(loopName, patchLength); // created variables are automatically added to `substitutions` var createdVariables = Variables.Select(v => _patchVar(v, substitutions)).ToList(); for (int i = 0; i < Sequences.Count; ++i) { var seq = patchSequences[i]; var patchVar = createdVariables[i]; patchVar.Match( (Tensor <float> .Var varF) => (ITensorVar)result.AddSeq((Tensor <float>)seq, varF, SequenceAxes[i]), (Tensor <int> .Var varI) => result.AddSeq((Tensor <int>)seq, varI, SequenceAxes[i]), null ); } foreach (var @for in Fors) { IFor patchFor; // we need to patch the expression to use the new variables var patchExpr = (ITensor)@for.Expression.Patch(substitutions); if (@for.IsRecursive) { // the output has been patched already, it should be in substitutions. var patchOutput = (ITensor)@for.OutputInfo.Patch(substitutions); // we just put the patched version of the variable in substitutions var patchVar = (ITensorVar)@for.RecursiveVariable.Patch(substitutions); patchFor = result.AddRecursive_(patchExpr, patchOutput, patchVar); } else { patchFor = result.AddOutput_(patchExpr); } substitutions.Add_(@for, patchFor); } } substitutions.Add(this, result); return(result); }