Ejemplo n.º 1
0
        /// <summary>
        /// Appends a tape node for <paramref name="result"/> to the active tape.
        /// Inputs and gradients are both keyed by position ("0", "1", ...).
        /// </summary>
        /// <param name="inputs">Input tensors; input k is stored on the node under key k.</param>
        /// <param name="result">Output tensor stored on the node.</param>
        /// <param name="gradientsFunc">
        /// Given the upstream gradient dy, returns an array of tensors. Element k is
        /// wrapped in a thunk and stored in the node's gradient map under key k.
        /// </param>
        private void addTapeNode(Tensor[] inputs, Tensor result, Func <Tensor, Tensor[]> gradientsFunc)
        {
            var namedInputs = new Dictionary <string, Tensor>();
            var position    = 0;
            foreach (var input in inputs)
            {
                namedInputs.Add(position.ToString(), input);
                position++;
            }

            // Adapt the positional array returned by gradientsFunc into a
            // NamedGradientMap keyed the same way as the inputs.
            Func <Tensor, NamedGradientMap> gradient = dy =>
            {
                var grads   = gradientsFunc(dy);
                var gradMap = new NamedGradientMap();
                for (var k = 0; k < grads.Length; k++)
                {
                    // Copy into a per-iteration local so each thunk returns its own tensor.
                    var g = grads[k];
                    gradMap.gradient.Add(k.ToString(), () => g);
                }
                return gradMap;
            };

            var node = new TapeNode()
            {
                id       = this.nextTapeNodeId++,
                name     = this.activeScope.name,
                inputs   = namedInputs,
                output   = result,
                gradient = gradient
            };

            this.activeTape.Add(node);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Runs a kernel's forward function against the current backend. When
        /// <c>shouldRecord()</c> is true afterwards, records a tape node for the call.
        /// </summary>
        /// <param name="forwardFunc">
        /// Computes the kernel result. It receives the backend and a save function;
        /// tensors passed to the save function are collected in order and later
        /// handed to <paramref name="grad"/>.
        /// </param>
        /// <param name="inputs">Named input tensors stored on the recorded tape node.</param>
        /// <param name="grad">
        /// Optional backward function. It is called with the upstream gradient and
        /// the tensors saved during the forward pass. When null, the recorded node
        /// gets no gradient function.
        /// </param>
        /// <returns>The tensor produced by <paramref name="forwardFunc"/>.</returns>
        public Tensor runKernel(ForwardFunc forwardFunc,
                                Dictionary <string, Tensor> inputs, Func <Tensor, List <Tensor>, NamedGradientMap> grad = null)
        {
            Tensor                result;
            List <Tensor>         saved    = new List <Tensor>();
            Func <Tensor, Tensor> saveFunc = (Tensor x) =>
            {
                saved.Add(x);
                return(x);
            };
            var scopeName = this.activeScope.name;

            // Stop recording to a tape while the kernel runs. The matching decrement
            // is in a finally block so that a kernel which throws cannot leave
            // customGradientDepth permanently raised.
            this.customGradientDepth++;
            try
            {
                result = forwardFunc(this.backend, saveFunc);
            }
            finally
            {
                // Continue recording after the kernel is done (or has failed).
                this.customGradientDepth--;
            }

            if (this.shouldRecord())
            {
                var tapeNode = new TapeNode()
                {
                    id     = this.nextTapeNodeId++,
                    name   = scopeName,
                    inputs = inputs,
                    output = result
                };

                if (grad != null)
                {
                    // Bind the tensors saved during the forward pass into the
                    // node's gradient function.
                    tapeNode.gradient = (Tensor dy) => grad(dy, saved);
                }
                this.activeTape.Add(tapeNode);
            }

            return(result);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Returns the tape nodes that lie on a path from any of <paramref name="xs"/>
        /// to <paramref name="y"/>. Each kept node's inputs are pruned to those that
        /// are themselves transitively a function of <paramref name="xs"/>.
        /// </summary>
        /// <param name="tape">Recorded tape nodes, in recording order.</param>
        /// <param name="xs">Source tensors. Duplicates are allowed.</param>
        /// <param name="y">Target tensor.</param>
        /// <returns>
        /// New <c>TapeNode</c> copies of the qualifying nodes, in tape order. Each
        /// copy has the same id, name, gradient and output, plus the pruned inputs.
        /// The nodes in <paramref name="tape"/> are not modified.
        /// </returns>
        public static TapeNode[] getFilteredNodesXToY(TapeNode[] tape, Tensor[] xs, Tensor y)
        {
            // Forward pass: collect every tensor and node that is transitively a
            // function of some x. Indexer assignment (rather than Add) makes marking
            // idempotent. Otherwise a tensor passed twice in xs, or a node whose
            // output is itself one of xs, would throw a duplicate-key exception.
            var tensorsFromX = new Dictionary <int, bool>();
            var nodesFromX   = new Dictionary <int, bool>();

            foreach (var x in xs)
            {
                tensorsFromX[x.id] = true;
            }

            foreach (var node in tape)
            {
                foreach (var input in node.inputs)
                {
                    if (tensorsFromX.ContainsKey(input.Value.id))
                    {
                        tensorsFromX[node.output.id] = true;
                        nodesFromX[node.id]          = true;
                        // A single input from x is enough to mark the node.
                        break;
                    }
                }
            }

            // Backward pass: walk the tape in reverse. Each node whose output leads
            // to y is marked, and so are all of its inputs.
            var tensorsLeadToY = new Dictionary <int, bool>();
            var nodesToY       = new Dictionary <int, bool>();

            tensorsLeadToY[y.id] = true;

            for (var i = tape.Length - 1; i >= 0; i--)
            {
                var node = tape[i];
                if (!tensorsLeadToY.ContainsKey(node.output.id))
                {
                    continue;
                }

                foreach (var input in node.inputs)
                {
                    tensorsLeadToY[input.Value.id] = true;
                    nodesToY[node.id]              = true;
                }
            }

            // Keep only nodes that both come from x and lead to y. Drop the inputs
            // of each kept node that are not a function of x.
            var filteredTape = new List <TapeNode>();

            foreach (var node in tape)
            {
                if (!nodesFromX.ContainsKey(node.id) || !nodesToY.ContainsKey(node.id))
                {
                    continue;
                }

                var prunedInputs = new Dictionary <string, Tensor>();
                foreach (var input in node.inputs)
                {
                    if (tensorsFromX.ContainsKey(input.Value.id))
                    {
                        prunedInputs.Add(input.Key, input.Value);
                    }
                }

                filteredTape.Add(new TapeNode()
                {
                    id       = node.id,
                    name     = node.name,
                    gradient = node.gradient,
                    inputs   = prunedInputs,
                    output   = node.output
                });
            }

            return filteredTape.ToArray();
        }