/// <summary>
/// Appends a node to the active tape for an operation given as positional
/// inputs and a positional gradient function.
/// </summary>
/// <param name="inputs">
/// Tensors consumed by the operation. Each one is stored under its index as a
/// string key ("0", "1", ...).
/// </param>
/// <param name="result">Tensor produced by the operation; stored as the node's output.</param>
/// <param name="gradientsFunc">
/// Given the upstream gradient <c>dy</c>, returns an array of tensors. The i-th
/// returned tensor is exposed under key i.ToString(), the same key used for
/// inputs[i]. The count is not checked against <paramref name="inputs"/>.
/// </param>
private void addTapeNode(Tensor[] inputs, Tensor result, Func<Tensor, Tensor[]> gradientsFunc)
{
    // Key every input by its position so gradients can be paired with it by key.
    var keyedInputs = new Dictionary<string, Tensor>();
    var position = 0;
    foreach (var tensor in inputs)
    {
        keyedInputs.Add(position.ToString(), tensor);
        position++;
    }

    // Adapt the positional gradient array into a NamedGradientMap using the same keys.
    Func<Tensor, NamedGradientMap> namedGradient = dy =>
    {
        var perInput = gradientsFunc(dy);
        var map = new NamedGradientMap();
        for (var k = 0; k < perInput.Length; k++)
        {
            // Fresh local per iteration so each thunk returns its own tensor.
            var captured = perInput[k];
            map.gradient.Add(k.ToString(), () => captured);
        }
        return map;
    };

    this.activeTape.Add(new TapeNode
    {
        id = this.nextTapeNodeId++,
        name = this.activeScope.name,
        inputs = keyedInputs,
        output = result,
        gradient = namedGradient
    });
}
/// <summary>
/// Runs a kernel through <paramref name="forwardFunc"/>. When recording is
/// active, it also appends a tape node describing the kernel invocation.
/// </summary>
/// <param name="forwardFunc">
/// Invoked with this engine's backend and a save callback. Every tensor passed
/// to the save callback is collected and later handed to <paramref name="grad"/>.
/// </param>
/// <param name="inputs">Named input tensors, stored as the tape node's inputs.</param>
/// <param name="grad">
/// Optional gradient function. If provided, the recorded node's gradient calls
/// it with the upstream gradient and the tensors saved during the forward call.
/// </param>
/// <returns>The tensor returned by <paramref name="forwardFunc"/>.</returns>
public Tensor runKernel(ForwardFunc forwardFunc, Dictionary<string, Tensor> inputs, Func<Tensor, List<Tensor>, NamedGradientMap> grad = null)
{
    var saved = new List<Tensor>();
    Func<Tensor, Tensor> saveFunc = (Tensor x) =>
    {
        saved.Add(x);
        return x;
    };
    // Capture the scope name before running the kernel.
    var scopeName = this.activeScope.name;

    Tensor result;
    // Stop recording to a tape while the kernel runs. The finally block ensures
    // a throwing kernel cannot leave customGradientDepth incremented, which
    // would keep recording suspended for all later operations.
    this.customGradientDepth++;
    try
    {
        result = forwardFunc(this.backend, saveFunc);
    }
    finally
    {
        // Continue recording after the kernel is done (or has failed).
        this.customGradientDepth--;
    }

    if (this.shouldRecord())
    {
        var tapeNode = new TapeNode()
        {
            id = this.nextTapeNodeId++,
            name = scopeName,
            inputs = inputs,
            output = result
        };
        if (grad != null)
        {
            tapeNode.gradient = (Tensor dy) => grad(dy, saved);
        }
        this.activeTape.Add(tapeNode);
    }
    return result;
}
/// <summary>
/// Returns the tape nodes that lie on a path from any tensor in
/// <paramref name="xs"/> to <paramref name="y"/>.
/// </summary>
/// <param name="tape">
/// Recorded tape nodes. The forward and backward passes assume producers
/// appear before their consumers.
/// </param>
/// <param name="xs">Source tensors. Duplicates are allowed.</param>
/// <param name="y">Target tensor.</param>
/// <returns>
/// New TapeNode instances, in tape order, for nodes that are both a function of
/// some x and lead to y. Only id, name, gradient, inputs and output are copied.
/// Each node's inputs are pruned to those that are a function of some x.
/// </returns>
public static TapeNode[] getFilteredNodesXToY(TapeNode[] tape, Tensor[] xs, Tensor y)
{
    var tensorsFromX = new HashSet<int>();
    var nodesFromX = new HashSet<int>();
    markFromX(tape, xs, tensorsFromX, nodesFromX);

    var nodesToY = markNodesLeadingToY(tape, y);

    // Return the paths that come from x and lead to y.
    var filteredTape = new List<TapeNode>();
    foreach (var node in tape)
    {
        if (!nodesFromX.Contains(node.id) || !nodesToY.Contains(node.id))
        {
            continue;
        }

        // Prune the inputs that are not a function of x.
        var prunedInputs = new Dictionary<string, Tensor>();
        foreach (var entry in node.inputs)
        {
            if (tensorsFromX.Contains(entry.Value.id))
            {
                prunedInputs.Add(entry.Key, entry.Value);
            }
        }

        filteredTape.Add(new TapeNode()
        {
            id = node.id,
            name = node.name,
            gradient = node.gradient,
            inputs = prunedInputs,
            output = node.output
        });
    }
    return filteredTape.ToArray();
}

// Forward pass: marks every tensor and node that is transitively a function of some x.
private static void markFromX(TapeNode[] tape, Tensor[] xs, HashSet<int> tensorsFromX, HashSet<int> nodesFromX)
{
    // HashSet.Add is idempotent. Dictionary.Add threw on duplicate xs, or when an
    // output was already marked (e.g. one x computed from another x).
    foreach (var x in xs)
    {
        tensorsFromX.Add(x.id);
    }

    foreach (var node in tape)
    {
        foreach (var input in node.inputs)
        {
            if (tensorsFromX.Contains(input.Value.id))
            {
                tensorsFromX.Add(node.output.id);
                nodesFromX.Add(node.id);
                break;
            }
        }
    }
}

// Backward pass: returns the ids of nodes whose output transitively leads to y.
private static HashSet<int> markNodesLeadingToY(TapeNode[] tape, Tensor y)
{
    var tensorsLeadToY = new HashSet<int> { y.id };
    var nodesToY = new HashSet<int>();
    for (var i = tape.Length - 1; i >= 0; i--)
    {
        var node = tape[i];
        if (!tensorsLeadToY.Contains(node.output.id))
        {
            continue;
        }

        // The node is marked only while visiting its inputs, so a node with no
        // inputs stays unmarked (it can never be a function of x anyway).
        foreach (var input in node.inputs)
        {
            tensorsLeadToY.Add(input.Value.id);
            nodesToY.Add(node.id);
        }
    }
    return nodesToY;
}