示例#1
0
        /// <summary>
        /// Runs the input through <c>bn1</c>, ReLU, <c>conv1</c>, <c>bn2</c>, ReLU and <c>conv2</c>,
        /// then adds the shortcut branch to the result.
        /// </summary>
        /// <param name="x">Input value, either an NDArray or a Symbol.</param>
        /// <param name="args">Extra arguments forwarded unchanged to every child <c>Call</c>.</param>
        /// <returns>The element-wise sum of the main branch and the shortcut branch.</returns>
        public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
        {
            // Applies ReLU through whichever backend matches the kind of the value.
            NDArrayOrSymbol Relu(NDArrayOrSymbol value)
            {
                if (value.IsNDArray)
                {
                    return nd.Activation(value.NdX, ActivationType.Relu);
                }

                return sym.Activation(value.SymX, ActivationType.Relu);
            }

            // By default the shortcut is the untouched input.
            var shortcut = x;

            x = Relu(bn1.Call(x, args));

            // When a downsample layer exists, the shortcut is taken from the activated value instead.
            if (ds != null)
            {
                shortcut = ds.Call(x, args);
            }

            x = conv1.Call(x, args);
            x = Relu(bn2.Call(x, args));
            x = conv2.Call(x, args);

            if (x.IsNDArray)
            {
                return x.NdX + shortcut.NdX;
            }

            return x.SymX + shortcut.SymX;
        }
示例#2
0
        /// <summary>
        /// Applies a fully connected transform to <paramref name="x"/> and, when an activation is
        /// configured, passes the result through it.
        /// </summary>
        /// <param name="x">Input value, either an NDArray or a Symbol.</param>
        /// <param name="args">
        /// <c>args[0]</c> is the weight; <c>args[1]</c>, when present, is the bias.
        /// </param>
        /// <returns>The transformed (and optionally activated) value.</returns>
        public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
        {
            var kernel = args[0];

            // The bias is optional: only read it when a second argument was supplied.
            NDArrayOrSymbol offset = null;
            if (args.Length > 1)
            {
                offset = args[1];
            }

            NDArrayOrSymbol result = null;

            if (x.IsNDArray)
            {
                result = nd.FullyConnected(x.NdX, kernel, offset, Units, !UseBias, Flatten_);
            }

            if (x.IsSymbol)
            {
                result = sym.FullyConnected(x.SymX, kernel, offset, Units, !UseBias, Flatten_);
            }

            if (Act == null)
            {
                return result;
            }

            return Act.HybridForward(result);
        }
        /// <summary>
        /// Runs the input through every child cell in order, giving each cell its own slice of
        /// <paramref name="states"/> and collecting the states each cell returns.
        /// </summary>
        /// <param name="inputs">Input for the first cell; each cell's output feeds the next one.</param>
        /// <param name="states">
        /// Flat list of states for all child cells, laid out in child order. Each cell consumes as many
        /// entries as its <c>StateInfo()</c> reports.
        /// </param>
        /// <returns>
        /// The output of the last cell and the flat list of new states of all cells, in child order.
        /// </returns>
        /// <exception cref="Exception">Thrown when a child is a <c>BidirectionalCell</c>.</exception>
        public override (NDArrayOrSymbol, NDArrayOrSymbol[]) Call(NDArrayOrSymbol inputs,
                                                                  params NDArrayOrSymbol[] states)
        {
            _counter++;
            var next_states = new List <NDArrayOrSymbol>();
            var p           = 0;

            foreach (var cell in _childrens.Values)
            {
                if (cell.GetType().Name == "BidirectionalCell")
                {
                    throw new Exception("BidirectionalCell not allowed");
                }

                // Hand this cell exactly the states that belong to it.
                var n     = cell.StateInfo().Length;
                var state = states.Skip(p).Take(n).ToArray();

                p += n;
                (inputs, state) = cell.Call(inputs, state);
                next_states.AddRange(state);
            }

            // next_states is already the flat concatenation of every cell's states; return it as-is
            // instead of collapsing it into a single summed value, so the result can be fed back in
            // as the `states` argument of the next step.
            return(inputs, next_states.ToArray());
        }
示例#4
0
        /// <summary>
        /// Prepares a sequence input for unrolling: locates the time and batch axes in
        /// <paramref name="layout"/> and, unless <paramref name="merge"/> is set, splits the input
        /// along the time axis.
        /// </summary>
        /// <param name="length">
        /// Number of time steps. Required when splitting a Symbol; for an NDArray it is optional and,
        /// when given, must equal the size of the time axis.
        /// </param>
        /// <param name="inputs">The sequence to format, either an NDArray or a Symbol.</param>
        /// <param name="layout">Layout string; the positions of 'T' and 'N' give the time and batch axes.</param>
        /// <param name="merge">When true the input is kept as one value; when false it is split.</param>
        /// <param name="in_layout">Layout of <paramref name="inputs"/>; defaults to <paramref name="layout"/>.</param>
        /// <returns>
        /// The formatted inputs, the time axis of <paramref name="layout"/>, and the batch size
        /// (0 for Symbol input, where the shape is not read).
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="inputs"/> is null, or <paramref name="length"/> is null when a Symbol must be split.
        /// </exception>
        /// <exception cref="Exception">
        /// The Symbol is grouped, or <paramref name="length"/> disagrees with the NDArray's time axis.
        /// </exception>
        internal static (NDArrayOrSymbol[], int, int) FormatSequence(int?length, NDArrayOrSymbol inputs, string layout,
                                                                     bool merge, string in_layout = null)
        {
            if (inputs is null)
            {
                throw new ArgumentNullException(nameof(inputs));
            }

            var axis       = layout.IndexOf('T');
            var batch_axis = layout.IndexOf('N');
            var batch_size = 0;
            var in_axis    = !string.IsNullOrWhiteSpace(in_layout) ? in_layout.IndexOf('T') : axis;

            NDArrayOrSymbol[] data_inputs = null;
            if (inputs.IsSymbol)
            {
                if (merge)
                {
                    // Nothing to split: pass the merged input through instead of dropping it.
                    data_inputs = new[] { inputs };
                }
                else
                {
                    if (inputs.SymX.ListOutputs().Count != 1)
                    {
                        throw new Exception("unroll doesn't allow grouped symbol as input. Please convert " +
                                            "to list with list(inputs) first or let unroll handle splitting.");
                    }

                    // The number of outputs cannot be read from the symbol here, so it must be supplied.
                    if (!length.HasValue)
                    {
                        throw new ArgumentNullException(nameof(length),
                                                        "length is required to split a Symbol input.");
                    }

                    data_inputs = new NDArrayOrSymbol[] { sym.Split(inputs.SymX, length.Value, in_axis, true) };
                }
            }
            else if (inputs.IsNDArray)
            {
                batch_size = inputs.NdX.Shape[batch_axis];
                if (merge)
                {
                    // Nothing to split: pass the merged input through instead of dropping it.
                    data_inputs = new[] { inputs };
                }
                else
                {
                    var steps = inputs.NdX.Shape[in_axis];

                    // A null length means "use whatever the array has"; only a mismatch is an error.
                    if (length.HasValue && length.Value != steps)
                    {
                        throw new Exception("Invalid length!");
                    }

                    data_inputs = nd.Split(inputs.NdX, steps, in_axis, true).NDArrayOrSymbols;
                }
            }

            return(data_inputs, axis, batch_size);
        }
示例#5
0
 /// <summary>
 /// Placeholder for the auxiliary forward pass; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 private NDArrayOrSymbol AuxForward(NDArrayOrSymbol pred1, NDArrayOrSymbol pred2, NDArrayOrSymbol label, NDArrayOrSymbol sample_weight = null, params object[] args)
     => throw new NotImplementedException();
示例#6
0
 /// <summary>
 /// Placeholder for the forward computation; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol pred, NDArrayOrSymbol label, NDArrayOrSymbol sample_weight = null, params object[] args)
     => throw new NotImplementedException();
示例#7
0
 /// <summary>
 /// Feeds the input through <c>Features</c> and then through <c>Output</c>.
 /// </summary>
 /// <param name="x">Input value.</param>
 /// <param name="args">Extra arguments forwarded to both child calls.</param>
 /// <returns>The result of the <c>Output</c> call.</returns>
 public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
 {
     NDArrayOrSymbol extracted = Features.Call(x, args);
     NDArrayOrSymbol result    = Output.Call(extracted, args);
     return result;
 }
示例#8
0
 /// <summary>
 /// Feeds the input through <c>Features</c> and then through <c>Output</c>.
 /// </summary>
 /// <param name="F">Not used by this method.</param>
 /// <param name="x">Input value.</param>
 /// <returns>The result of the <c>Output</c> call.</returns>
 public NDArrayOrSymbol HybridForward(object F, NDArrayOrSymbol x)
 {
     NDArrayOrSymbol extracted = Features.Call(x);
     NDArrayOrSymbol result    = Output.Call(extracted);
     return result;
 }
示例#9
0
 /// <summary>
 /// Placeholder for prediction; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public NDArrayOrSymbol Predict(NDArrayOrSymbol x)
     => throw new NotImplementedException();
示例#10
0
 public new virtual (NDArrayOrSymbol, NDArrayOrSymbol[]) Call(NDArrayOrSymbol inputs,
                                                              params NDArrayOrSymbol[] states)
 {
     return(default);
示例#11
0
 /// <summary>
 /// Placeholder for the fast path; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public (NDArrayOrSymbol, NDArrayOrSymbolList) FastPath(NDArrayOrSymbol x)
     => throw new NotImplementedException();
示例#12
0
 /// <summary>
 /// Multiplies the input by the result of passing it through <c>_act</c>.
 /// </summary>
 /// <param name="x">Input value.</param>
 /// <param name="args">Not used by this method.</param>
 /// <returns><c>x * _act(x)</c>.</returns>
 public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
 {
     var gate = this._act.Call(x);
     return x * gate;
 }
示例#13
0
 /// <summary>
 /// Shifts the input by 3, passes it through <c>_act</c>, and divides the result by 6.
 /// </summary>
 /// <param name="x">Input value.</param>
 /// <param name="args">Not used by this method.</param>
 /// <returns><c>_act(x + 3) / 6</c>.</returns>
 public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
 {
     var shifted   = x + 3;
     var activated = _act.Call(shifted);
     return activated / 6;
 }
示例#14
0
        /// <summary>
        /// Runs the wrapped cell for one step and applies zoneout: with the configured probabilities,
        /// elements of the new output/states are replaced by the corresponding previous values.
        /// </summary>
        /// <param name="x">Input for this step, either an NDArray or a Symbol.</param>
        /// <param name="args">
        /// Arguments forwarded to the wrapped cell; they are also used as the previous states when
        /// state zoneout is enabled.
        /// </param>
        /// <returns>The (possibly zoned-out) output and states for this step.</returns>
        public override (NDArrayOrSymbol, NDArrayOrSymbol[]) HybridForward(NDArrayOrSymbol x,
                                                                           params NDArrayOrSymbol[] args)
        {
            var(cell, p_outputs, p_states) = (BaseCell, ZoneoutOutputs, ZoneoutStates);
            var(next_output, next_states)  = cell.Call(x, args);

            // Builds a dropout mask shaped like `like` (not like the step input, whose shape may differ
            // from the output/state it is applied to).
            NDArrayOrSymbol mask(float p, NDArrayOrSymbol like)
            {
                if (x.IsNDArray)
                {
                    return(nd.Dropout(nd.OnesLike(like), p));
                }

                return(sym.Dropout(sym.OnesLike(like), p));
            }

            var prev_output = _prev_output;

            // On the first step there is no previous output yet; fall back to zeros.
            if (prev_output == null)
            {
                prev_output = x.IsNDArray
                    ? new NDArrayOrSymbol(nd.ZerosLike(next_output))
                    : new NDArrayOrSymbol(sym.ZerosLike(next_output));
            }

            NDArrayOrSymbol output = null;

            NDArrayOrSymbol[] states = null;
            if (x.IsNDArray)
            {
                output = p_outputs != 0
                    ? new NDArrayOrSymbol(nd.Where(mask(p_outputs, next_output), next_output, prev_output))
                    : next_output;

                if (p_states == 0)
                {
                    states = next_states;
                }
                else
                {
                    // Pair every new state with the state that was passed in, and keep the result.
                    states = next_states
                             .Zip(args,
                                  (new_s, old_s) =>
                                  new NDArrayOrSymbol(nd.Where(mask(p_states, new_s), new_s, old_s)))
                             .ToArray();
                }
            }
            else if (x.IsSymbol)
            {
                output = p_outputs != 0
                    ? new NDArrayOrSymbol(sym.Where(mask(p_outputs, next_output), next_output, prev_output))
                    : next_output;

                if (p_states == 0)
                {
                    states = next_states;
                }
                else
                {
                    // Pair every new state with the state that was passed in, and keep the result.
                    states = next_states
                             .Zip(args,
                                  (new_s, old_s) =>
                                  new NDArrayOrSymbol(sym.Where(mask(p_states, new_s), new_s, old_s)))
                             .ToArray();
                }
            }

            // Remember this step's output so the next step can zone out against it.
            _prev_output = output;

            return(output, states);
        }
示例#15
0
 /// <summary>
 /// Resets the base class state and forgets the stored previous output.
 /// </summary>
 public override void Reset()
 {
     base.Reset();
     // Clear the cached output so the next use starts without a previous value.
     _prev_output = null;
 }
示例#16
0
        /// <summary>
        /// Unrolls the left cell over the sequence and the right cell over the reversed sequence, then
        /// re-reverses the right outputs and concatenates them with the left outputs.
        /// </summary>
        /// <param name="length">Number of time steps to unroll.</param>
        /// <param name="inputs">The input sequence.</param>
        /// <param name="begin_state">
        /// Initial states for both cells: left cell states first, right cell states after them.
        /// When null, <c>RNNCell.GetBeginState</c> supplies them.
        /// </param>
        /// <param name="layout">Layout string of the inputs and outputs.</param>
        /// <param name="merge_outputs">
        /// Whether outputs are merged into one value; when null it is inferred from the left outputs.
        /// </param>
        /// <param name="valid_length">Optional per-sample valid lengths used for reversing and masking.</param>
        /// <returns>The combined outputs and the concatenated final states (left, then right).</returns>
        public override (NDArrayOrSymbol[], NDArrayOrSymbol[]) Unroll(int length, NDArrayOrSymbol[] inputs,
                                                                      NDArrayOrSymbol[] begin_state = null, string layout = "NTC", bool?merge_outputs = null,
                                                                      Symbol valid_length           = null)
        {
            Reset();
            var axis       = 0;
            var batch_size = 0;

            (inputs, axis, batch_size) = RNNCell.FormatSequence(length, inputs, layout, false);
            var reversed_inputs = RNNCell._reverse_sequences(inputs, length, valid_length);

            begin_state = RNNCell.GetBeginState(this, begin_state, inputs, batch_size);
            var states = begin_state.ToList();
            var l_cell = _childrens["l_cell"];
            var r_cell = _childrens["r_cell"];

            // The first block of states belongs to the left cell, the remainder to the right cell.
            var l_state_count = l_cell.StateInfo().Length;

            var(l_outputs, l_states) = l_cell.Unroll(length, inputs, states.Take(l_state_count).ToArray(),
                                                     layout, merge_outputs, valid_length);

            // The right cell must see the sequence back-to-front.
            var(r_outputs, r_states) = r_cell.Unroll(length, reversed_inputs, states.Skip(l_state_count).ToArray(),
                                                     layout, merge_outputs, valid_length);

            // Put the right cell's outputs back into forward time order.
            var reversed_r_outputs = RNNCell._reverse_sequences(r_outputs, length, valid_length);

            if (!merge_outputs.HasValue)
            {
                // NOTE(review): a merged result would normally be a single-element array, so
                // `Length > 1` may be the inverse of what is intended — confirm against callers.
                merge_outputs = l_outputs.Length > 1;

                (l_outputs, _, _)          = RNNCell.FormatSequence(null, l_outputs, layout, merge_outputs.Value);
                (reversed_r_outputs, _, _) =
                    RNNCell.FormatSequence(null, reversed_r_outputs, layout, merge_outputs.Value);
            }

            NDArrayOrSymbol[] outputs = null;
            if (merge_outputs.Value)
            {
                // Stack the per-step right outputs along the time axis.
                if (reversed_r_outputs[0].IsNDArray)
                {
                    reversed_r_outputs = new NDArrayOrSymbol[]
                    {
                        nd.Stack(reversed_r_outputs.ToList().ToNDArrays(), reversed_r_outputs.Length, axis)
                    };
                }
                else
                {
                    reversed_r_outputs = new NDArrayOrSymbol[]
                    {
                        sym.Stack(reversed_r_outputs.ToList().ToSymbols(), reversed_r_outputs.Length, axis)
                    };
                }

                var concatList = l_outputs.ToList();
                concatList.AddRange(reversed_r_outputs);
                if (reversed_r_outputs[0].IsNDArray)
                {
                    outputs = new NDArrayOrSymbol[] { nd.Concat(concatList.ToList().ToNDArrays(), 2) };
                }
                else
                {
                    outputs = new NDArrayOrSymbol[] { sym.Concat(concatList.ToList().ToSymbols(), 2) };
                }
            }
            else
            {
                // Concatenate left and right outputs step by step.
                var outputs_temp = new List <NDArrayOrSymbol>();
                for (var i = 0; i < l_outputs.Length; i++)
                {
                    var l_o = l_outputs[i];
                    var r_o = reversed_r_outputs[i];
                    if (l_o.IsNDArray)
                    {
                        // Same dimension as the symbolic branch below.
                        outputs_temp.Add(nd.Concat(new NDArray[] { l_o, r_o }, 1));
                    }
                    else
                    {
                        outputs_temp.Add(sym.Concat(new Symbol[] { l_o, r_o }, 1, symbol_name: $"{_output_prefix}t{i}"));
                    }
                }

                outputs = outputs_temp.ToArray();
            }


            if (valid_length != null)
            {
                outputs = RNNCell.MaskSequenceVariableLength(outputs, length, valid_length, axis, merge_outputs.Value);
            }

            states.Clear();
            states.AddRange(l_states);
            states.AddRange(r_states);

            return(outputs, states.ToArray());
        }
示例#17
0
 /// <summary>
 /// Placeholder for flip inference; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public NDArrayOrSymbol FlipInference(NDArrayOrSymbol image)
     => throw new NotImplementedException();
示例#18
0
 /// <summary>
 /// Placeholder for invoking this object on an image; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public NDArrayOrSymbol Call(NDArrayOrSymbol image)
     => throw new NotImplementedException();
示例#19
0
 /// <summary>
 /// Placeholder for the slow path; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public NDArrayOrSymbol SlowPath(NDArrayOrSymbol x, NDArrayOrSymbolList lateral)
     => throw new NotImplementedException();
        /// <summary>
        /// Computes the binary cross-entropy loss between <paramref name="pred"/> and
        /// <paramref name="label"/>, optionally weighting the positive class.
        /// </summary>
        /// <param name="pred">
        /// Predictions: raw scores when <c>_from_sigmoid</c> is false, otherwise values that already
        /// went through a sigmoid.
        /// </param>
        /// <param name="label">Targets; reshaped to the shape of <paramref name="pred"/>.</param>
        /// <param name="sample_weight">Optional weighting passed to <c>ApplyWeighting</c>.</param>
        /// <param name="args">
        /// Optional; <c>args[0]</c> is the positive-class weight as an <c>NDArray</c>, a <c>Symbol</c>
        /// or an <c>NDArrayOrSymbol</c>. A missing or null entry means "no positive weight".
        /// </param>
        /// <returns>The loss averaged over every axis except the batch axis.</returns>
        public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol pred, NDArrayOrSymbol label,
                                                      NDArrayOrSymbol sample_weight = null, params object[] args)
        {
            if (label.IsNDArray)
            {
                label = nd.ReshapeLike(label, pred);
            }
            else
            {
                label = sym.ReshapeLike(label, pred);
            }

            NDArrayOrSymbol pos_weight = null;
            NDArrayOrSymbol loss       = null;

            // Accept the positive weight in any of its three representations.
            if (args != null && args.Length > 0 && args[0] != null)
            {
                if (args[0] is NDArrayOrSymbol wrapped)
                {
                    pos_weight = wrapped;
                }
                else if (args[0] is NDArray array)
                {
                    pos_weight = new NDArrayOrSymbol(array);
                }
                else
                {
                    pos_weight = new NDArrayOrSymbol((Symbol)args[0]);
                }
            }

            if (!_from_sigmoid)
            {
                if (pos_weight == null)
                {
                    // Numerically stable form: max(x, 0) - x * z + log(1 + exp(-|x|)).
                    if (label.IsNDArray)
                    {
                        loss = nd.Relu(pred) - pred.NdX * label.NdX +
                               nd.Activation(nd.Negative(nd.Abs(pred)), ActivationType.Softrelu);
                    }
                    else
                    {
                        loss = sym.Relu(pred) - pred.SymX * label.SymX +
                               sym.Activation(sym.Negative(sym.Abs(pred)), ActivationType.Softrelu);
                    }
                }
                else
                {
                    // Numerically stable form with a positive-class weight w:
                    //   x - x * z + (1 + (w - 1) * z) * (log(1 + exp(-|x|)) + max(-x, 0)).
                    // The log-weight multiplies the bracketed term; it is not an additive offset.
                    if (label.IsNDArray)
                    {
                        var log_weight = 1 + nd.BroadcastMul(pos_weight.NdX - 1, label);
                        loss = pred.NdX - pred.NdX * label.NdX
                               + log_weight * (nd.Activation(nd.Negative(nd.Abs(pred)),
                                                             ActivationType.Softrelu)
                                               + nd.Relu(nd.Negative(pred)));
                    }
                    else
                    {
                        var log_weight = 1 + sym.BroadcastMul(pos_weight.SymX - 1, label);
                        loss = pred.SymX - pred.SymX * label.SymX
                               + log_weight * (sym.Activation(sym.Negative(sym.Abs(pred)),
                                                              ActivationType.Softrelu)
                                               + sym.Relu(sym.Negative(pred)));
                    }
                }
            }
            else
            {
                // Small constant that keeps log() away from zero.
                var eps = 1e-12f;
                if (pos_weight == null)
                {
                    if (label.IsNDArray)
                    {
                        loss = nd.Negative(nd.Log(pred.NdX + eps) * label.NdX
                                           + nd.Log(1 - pred.NdX + eps) * (1 - label.NdX));
                    }
                    else
                    {
                        loss = sym.Negative(sym.Log(pred.SymX + eps) * label.SymX
                                            + sym.Log(1 - pred.SymX + eps) * (1 - label.SymX));
                    }
                }
                else
                {
                    if (label.IsNDArray)
                    {
                        loss = nd.Negative(nd.BroadcastMul(nd.Log(pred.NdX + eps) * label.NdX, pos_weight)
                                           + nd.Log(1 - pred.NdX + eps) * (1 - label.NdX));
                    }
                    else
                    {
                        loss = sym.Negative(sym.BroadcastMul(sym.Log(pred.SymX + eps) * label.SymX, pos_weight)
                                            + sym.Log(1 - pred.SymX + eps) * (1 - label.SymX));
                    }
                }
            }

            loss = ApplyWeighting(loss, Weight, sample_weight);
            if (loss.IsNDArray)
            {
                return(nd.Mean(loss, BatchAxis.Value, exclude: true));
            }

            return(sym.Mean(loss, BatchAxis.Value, exclude: true));
        }
示例#21
0
 /// <summary>
 /// Placeholder for the demo entry point; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public NDArrayOrSymbol Demo(NDArrayOrSymbol x)
     => throw new NotImplementedException();
示例#22
0
 /// <summary>
 /// Applies <c>Function</c> to the input.
 /// </summary>
 /// <param name="input">Value passed to <c>Function</c>.</param>
 /// <param name="args">Not used by this method.</param>
 /// <returns>Whatever <c>Function</c> returns for <paramref name="input"/>.</returns>
 public override NDArrayOrSymbol Forward(NDArrayOrSymbol input, params NDArrayOrSymbol[] args)
     => Function(input);
示例#23
0
 /// <summary>
 /// Identity: returns the input unchanged.
 /// </summary>
 /// <param name="x">Input value.</param>
 /// <param name="args">Not used by this method.</param>
 /// <returns><paramref name="x"/> itself.</returns>
 public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
     => x;
示例#24
0
 /// <summary>
 /// Placeholder for the base forward pass; not implemented here.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown by this implementation.</exception>
 public virtual (NDArrayOrSymbol, NDArrayOrSymbol) BaseForward(NDArrayOrSymbol x)
     => throw new NotImplementedException();
示例#25
0
 public override (NDArrayOrSymbol, NDArrayOrSymbol[]) HybridForward(NDArrayOrSymbol x,
                                                                    params NDArrayOrSymbol[] args)
 {
     return(default);
示例#26
0
 /// <summary>
 /// Placeholder for evaluation; not implemented here.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown by this implementation.</exception>
 public virtual NDArrayOrSymbol Evaluate(NDArrayOrSymbol x)
     => throw new NotImplementedException();
示例#27
0
 /// <summary>
 /// Placeholder for the forward computation; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
     => throw new NotImplementedException();
示例#28
0
 /// <summary>
 /// Forward computation that every concrete subclass must implement.
 /// </summary>
 /// <param name="input">Primary input value.</param>
 /// <param name="args">Additional inputs.</param>
 /// <returns>The result of the computation.</returns>
 public abstract NDArrayOrSymbol Forward(NDArrayOrSymbol input, params NDArrayOrSymbol[] args);
示例#29
0
        /// <summary>
        /// Computes one step of a gated recurrent update: two fully connected projections (input and
        /// previous hidden state) are each split into three parts that drive a reset gate, an update
        /// gate and a candidate state, which are then blended with the previous hidden state.
        /// </summary>
        /// <param name="x">Input for this step, either an NDArray or a Symbol.</param>
        /// <param name="args">
        /// Expected order: previous hidden state, input-to-hidden weight, hidden-to-hidden weight,
        /// input-to-hidden bias, hidden-to-hidden bias.
        /// </param>
        /// <returns>The new hidden state, both as the output and as the single-element state list.</returns>
        public override (NDArrayOrSymbol, NDArrayOrSymbol[]) HybridForward(NDArrayOrSymbol x,
                                                                           params NDArrayOrSymbol[] args)
        {
            // Prefix used to name the symbols created in the symbolic branch.
            var             prefix       = $"t{_counter}_";
            var             prev_state_h = args[0];
            var             i2h_weight   = args[1];
            var             h2h_weight   = args[2];
            var             i2h_bias     = args[3];
            var             h2h_bias     = args[4];
            NDArrayOrSymbol next_h       = null;

            if (x.IsNDArray)
            {
                // Both projections produce 3 * hidden_size units: one third per gate/candidate.
                var i2h      = nd.FullyConnected(x, i2h_weight, i2h_bias, _hidden_size * 3);
                var h2h      = nd.FullyConnected(prev_state_h, h2h_weight, h2h_bias, _hidden_size * 3);
                var i2hsplit = nd.Split(i2h, 3);
                var i2h_r    = i2hsplit[0];
                var i2h_z    = i2hsplit[1];
                // Reuse the variable for the third slice (candidate contribution).
                i2h = i2hsplit[2];

                var h2hsplit = nd.Split(h2h, 3);
                var h2h_r    = h2hsplit[0];
                var h2h_z    = h2hsplit[1];
                // Reuse the variable for the third slice (candidate contribution).
                h2h = h2hsplit[2];

                var reset_gate  = Activation(nd.ElemwiseAdd(i2h_r, h2h_r), "sigmoid");
                var update_gate = Activation(nd.ElemwiseAdd(i2h_z, h2h_z), "sigmoid");
                // Candidate state: tanh(i2h + reset_gate * h2h).
                var next_h_tmp  = Activation(nd.ElemwiseAdd(i2h,
                                                            nd.ElemwiseMul(reset_gate, h2h)),
                                             "tanh");
                var ones = nd.OnesLike(update_gate);
                // next_h = (1 - update_gate) * candidate + update_gate * prev_state_h.
                next_h = nd.ElemwiseAdd(nd.ElemwiseMul(nd.ElemwiseSub(ones, update_gate),
                                                       next_h_tmp),
                                        nd.ElemwiseMul(update_gate, prev_state_h));
            }
            else
            {
                // Symbolic branch: same computation, with every node given a prefixed name.
                var i2h = sym.FullyConnected(x, i2h_weight, i2h_bias, _hidden_size * 3, symbol_name: prefix + "i2h");
                var h2h = sym.FullyConnected(prev_state_h, h2h_weight, h2h_bias, _hidden_size * 3,
                                             symbol_name: prefix + "h2h");
                var i2hsplit = sym.Split(i2h, 3, symbol_name: prefix + "i2h_slice");
                var i2h_r    = i2hsplit[0];
                var i2h_z    = i2hsplit[1];
                // Reuse the variable for the third slice (candidate contribution).
                i2h = i2hsplit[2];

                var h2hsplit = sym.Split(h2h, 3, symbol_name: prefix + "h2h_slice");
                var h2h_r    = h2hsplit[0];
                var h2h_z    = h2hsplit[1];
                // Reuse the variable for the third slice (candidate contribution).
                h2h = h2hsplit[2];

                var reset_gate = Activation(sym.ElemwiseAdd(i2h_r, h2h_r, prefix + "plus0"), "sigmoid",
                                            name: prefix + "r_act");
                var update_gate = Activation(sym.ElemwiseAdd(i2h_z, h2h_z, prefix + "plus1"), "sigmoid",
                                             name: prefix + "z_act");
                // Candidate state: tanh(i2h + reset_gate * h2h).
                var next_h_tmp = Activation(
                    sym.ElemwiseAdd(i2h, sym.ElemwiseMul(reset_gate, h2h, prefix + "mul0"), prefix + "plus2"),
                    "tanh", name: prefix + "h_act");
                var ones = sym.OnesLike(update_gate, prefix + "ones_like0");
                // next_h = (1 - update_gate) * candidate + update_gate * prev_state_h.
                next_h = sym.ElemwiseAdd(sym.ElemwiseMul(sym.ElemwiseSub(ones, update_gate, prefix + "minus0"),
                                                         next_h_tmp, prefix + "mul1"),
                                         sym.ElemwiseMul(update_gate, prev_state_h, prefix + "mul2"), prefix + "out");
            }

            // The new hidden state is both the step output and the only state.
            return(next_h, new[] { next_h });
        }
示例#30
0
        /// <summary>
        /// Delegates the loss computation to <c>F</c>, unwrapping <paramref name="pred"/> to the
        /// representation (NDArray or Symbol) it actually holds.
        /// </summary>
        /// <param name="pred">Predictions, either an NDArray or a Symbol.</param>
        /// <param name="label">Targets, passed to <c>F</c> unchanged.</param>
        /// <param name="sample_weight">Not used by this method.</param>
        /// <param name="args">Not used by this method.</param>
        /// <returns>The value produced by <c>F</c>.</returns>
        public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol pred, NDArrayOrSymbol label, NDArrayOrSymbol sample_weight = null, params object[] args)
        {
            NDArrayOrSymbol result;

            if (pred.IsNDArray)
            {
                result = F(pred.NdX, label);
            }
            else
            {
                result = F(pred.SymX, label);
            }

            return result;
        }