Пример #1
0
        /// <summary>
        /// Loads the pipeline configuration and assembles an estimator together with
        /// placeholder input functions.
        /// </summary>
        /// <remarks>
        /// The train / eval-on-train / predict input functions and the model function are all
        /// empty no-op delegates. <c>eval_input_fns</c> is only allocated (one slot per eval
        /// input reader config); its elements are left unset.
        /// <paramref name="hparams"/>, <paramref name="sample_1_of_n_eval_examples"/> and
        /// <paramref name="sample_1_of_n_eval_on_train_examples"/> are not read by this method.
        /// </remarks>
        /// <param name="run_config">Forwarded to the estimator as its <c>config</c>.</param>
        /// <param name="hparams">Currently unused.</param>
        /// <param name="pipeline_config_path">Passed unchanged to <c>ConfigUtil.get_configs_from_pipeline_file</c>.</param>
        /// <param name="train_steps">Copied verbatim into the result.</param>
        /// <param name="sample_1_of_n_eval_examples">Currently unused.</param>
        /// <param name="sample_1_of_n_eval_on_train_examples">Currently unused.</param>
        /// <returns>A <c>TrainAndEvalDict</c> holding the estimator, the placeholder functions,
        /// the eval input names and <paramref name="train_steps"/>.</returns>
        public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
                                                            HParams hparams                          = null,
                                                            string pipeline_config_path              = null,
                                                            int train_steps                          = 0,
                                                            int sample_1_of_n_eval_examples          = 0,
                                                            int sample_1_of_n_eval_on_train_examples = 1)
        {
            var pipeline     = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);
            var eval_readers = pipeline.EvalInputReader;

            // Placeholder callbacks for TRAIN / EVAL-ON-TRAIN / PREDICT and the model itself.
            // Each one is its own delegate instance, exactly as before.
            Action noop_train         = delegate { };
            Action noop_eval_on_train = delegate { };
            Action noop_predict       = delegate { };
            Action noop_model         = delegate { };

            // One (still empty) slot and one name per configured eval input reader.
            var eval_fns   = new Action[eval_readers.Count];
            var eval_names = (from reader in eval_readers select reader.Name).ToArray();

            var built_estimator = tf.estimator.Estimator(model_fn: noop_model, config: run_config);

            return new TrainAndEvalDict
            {
                estimator              = built_estimator,
                train_input_fn         = noop_train,
                eval_input_fns         = eval_fns,
                eval_input_names       = eval_names,
                eval_on_train_input_fn = noop_eval_on_train,
                predict_input_fn       = noop_predict,
                train_steps            = train_steps
            };
        }
Пример #2
0
 /// <summary>
 /// Creates a trainer that keeps the given data set, encoder, hyperparameters,
 /// batch size, sample length and random source.
 /// </summary>
 /// <param name="dataset">Stored as-is; must not be null.</param>
 /// <param name="encoder">Stored as-is; must not be null.</param>
 /// <param name="hParams">Stored as-is; must not be null.</param>
 /// <param name="batchSize">Stored without validation.</param>
 /// <param name="sampleLength">Stored without validation.</param>
 /// <param name="random">Stored as-is; must not be null.</param>
 /// <exception cref="ArgumentNullException">
 /// A reference argument is null. Arguments are checked in the order
 /// dataset, encoder, hParams, random.
 /// </exception>
 public Gpt2Trainer(DataSet dataset, Gpt2Encoder encoder, HParams hParams,
                    int batchSize, int sampleLength, Random random)
 {
     // Validate everything first, in the same order the checks used to fire.
     if (dataset is null)
     {
         throw new ArgumentNullException(nameof(dataset));
     }
     if (encoder is null)
     {
         throw new ArgumentNullException(nameof(encoder));
     }
     if (hParams is null)
     {
         throw new ArgumentNullException(nameof(hParams));
     }
     if (random is null)
     {
         throw new ArgumentNullException(nameof(random));
     }

     this.dataset      = dataset;
     this.encoder      = encoder;
     this.hParams      = hParams;
     this.batchSize    = batchSize;
     this.sampleLength = sampleLength;
     this.random       = random;
 }
Пример #3
0
 /// <summary>
 /// Creates a TCP server from the given parameters object.
 /// </summary>
 /// <param name="hParams">
 /// Parameters object to keep; its <c>IPPort</c> value is converted with <c>ToInt()</c>
 /// and stored as the TCP port.
 /// </param>
 /// <exception cref="System.ArgumentNullException"><paramref name="hParams"/> is null.</exception>
 public TCPServer(HParams hParams)
 {
     // The port is read from hParams right away, so a null argument would otherwise
     // surface as a bare NullReferenceException with no hint about the culprit.
     if (hParams == null)
     {
         throw new System.ArgumentNullException(nameof(hParams));
     }

     _hParams = hParams;
     _tcpPort = _hParams.IPPort.ToInt();
 }
Пример #4
0
        /// <summary>
        /// Builds a sampling loop that runs the GPT-2 model repeatedly and appends one sampled
        /// token per iteration to <paramref name="context"/>.
        /// </summary>
        /// <param name="hParams">Hyperparameters; forwarded to <c>Gpt2Model.Model</c> and
        /// <c>Gpt2Model.PastShape</c>, and queried for <c>"n_vocab"</c>.</param>
        /// <param name="length">Passed to the loop as <c>maximum_iterations</c>.</param>
        /// <param name="startToken">Alternative to <paramref name="context"/>. Seeding the loop from
        /// a start token is not implemented here, so supplying it (and hence no context) throws.</param>
        /// <param name="batchSize">Used for the past-state shape and the loop's shape invariants.</param>
        /// <param name="context">Initial tokens. It is indexed with two indices, so it is presumably
        /// rank 2 (batch, position) — confirm against callers.</param>
        /// <param name="temperature">Divisor applied to the last-position logits before sampling.</param>
        /// <param name="topK">Forwarded to <c>TopLogits</c>.</param>
        /// <returns>The third loop variable after the loop ends: the context with every sampled
        /// token concatenated along axis 1.</returns>
        /// <exception cref="ArgumentException">Both or neither of <paramref name="startToken"/> and
        /// <paramref name="context"/> were supplied.</exception>
        /// <exception cref="NotImplementedException">Only <paramref name="startToken"/> was supplied.</exception>
        public static Tensor SampleSequence(HParams hParams, int length,
                                            string startToken = null, int? batchSize = null, dynamic context = null,
                                            float temperature = 1, int topK         = 0)
        {
            if (((startToken == null) ^ (context == null)) == false)
            {
                throw new ArgumentException($"Exactly one of {nameof(startToken)} or {nameof(context)} has to be specified");
            }

            // Everything below indexes `context` unconditionally, so a startToken-only call used
            // to pass the check above and then die inside the name scope with an opaque
            // runtime-binder null error. Fail up front with an explicit message instead.
            // NOTE(review): the reference Python sampler appears to seed the context from the
            // start token (tf.fill); that path has not been ported — confirm before adding it.
            if ((object)context == null)
            {
                throw new NotImplementedException(
                    $"Sampling from {nameof(startToken)} is not implemented; specify {nameof(context)} instead");
            }

            // Runs the model once and returns its logits (last axis cut to the first n_vocab
            // entries) together with the "present" state, whose static shape is pinned.
            SortedDictionary <string, dynamic> Step(HParams @params, Tensor tokens, dynamic past = null)
            {
                var lmOutput = Gpt2Model.Model(hParams: @params, input: tokens, past: past, reuse: _ReuseMode.AUTO_REUSE);

                var    logits   = lmOutput["logits"][Range.All, Range.All, Range.EndAt((int)@params.get("n_vocab"))];
                Tensor presents = lmOutput["present"];

                int?[] pastShape = Gpt2Model.PastShape(hParams: @params, batchSize: batchSize);
                presents.set_shape_(pastShape.Cast <object>());

                return(new SortedDictionary <string, object>
                {
                    ["logits"] = logits,
                    ["presents"] = presents,
                });
            }

            Tensor result = null;

            new name_scope("sample_sequence").Use(_ =>
            {
                // Don't feed the last context token -- leave that to the loop below
                // TODO: Would be slightly faster if we called step on the entire context,
                // rather than leaving the last token transformer calculation to the while loop.
                var contextOutput = Step(hParams, context[Range.All, Range.EndAt(new Index(1, fromEnd: true))]);

                // One loop iteration: run the model on the previous token, sample the next one
                // from the temperature-scaled, top-k filtered logits of the last position, and
                // return the updated (past, previous token, accumulated output) triple.
                Tensor[] Body(object past, dynamic prev, object output)
                {
                    var nextOutputs = Step(hParams, prev[Range.All, tf.newaxis], past: past);
                    Tensor logits   = nextOutputs["logits"][Range.All, -1, Range.All] / tf.to_float(temperature);
                    logits          = TopLogits(logits, topK: topK);
                    var samples     = tf.multinomial_dyn(logits, num_samples: 1, output_dtype: tf.int32);
                    return(new Tensor[]
                    {
                        tf.concat(new [] { past, nextOutputs["presents"] }, axis: -2),
                        tf.squeeze(samples, axis: new[] { 1 }),
                        tf.concat(new [] { output, samples }, axis: 1),
                    });
                }

                // The condition never stops the loop; maximum_iterations below bounds it.
                bool True(object _a, object _b, object _c) => true;

                dynamic[] loopVars = new[] {
                    contextOutput["presents"],
                    context[Range.All, -1],
                    context,
                };
                TensorShape[] shapeInvariants = new[] {
                    new TensorShape(Gpt2Model.PastShape(hParams: hParams, batchSize: batchSize)),
                    new TensorShape(batchSize),
                    new TensorShape((int?)batchSize, (int?)null),
                };
                // Index 2 picks the third loop variable: the accumulated output tokens.
                result = tf.while_loop(
                    cond: PythonFunctionContainer.Of <object, object, object, bool>(True),
                    body: PythonFunctionContainer.Of(new Func <object, object, object, Tensor[]>(Body)),
                    parallel_iterations: 10,
                    swap_memory: false,
                    name: null,
                    maximum_iterations: tf.constant(length),
                    loop_vars: loopVars,
                    shape_invariants: shapeInvariants,
                    back_prop: false)
                         [2];
            });
            return(result);
        }
Пример #5
0
 /// <summary>
 /// Replaces the stored parameters object and sets its <c>IPAddress</c> to the value
 /// returned by <c>NetworkClass.getComputerIpAddress()</c>.
 /// </summary>
 /// <remarks>
 /// The address is written through the stored reference, so if <c>HParams</c> is a
 /// reference type the caller's object is modified as well.
 /// </remarks>
 /// <param name="hParams">Parameters object to store.</param>
 /// <exception cref="System.ArgumentNullException"><paramref name="hParams"/> is null.</exception>
 public void SetParams(HParams hParams)
 {
     // Reject null before touching any state: previously a null argument first
     // overwrote the stored parameters with null and only then failed with a
     // NullReferenceException on the IPAddress assignment.
     if (hParams == null)
     {
         throw new System.ArgumentNullException(nameof(hParams));
     }

     _hParams           = hParams;
     _hParams.IPAddress = NetworkClass.getComputerIpAddress();
 }
Пример #6
0
        ///// <summary>
        ///// ctore
        ///// </summary>
        ///// <param name="udpPort">port to listen on</param>
        ///// <param name="networkName">network name to check</param>
        //public UDPServer(int udpPort, string networkName)
        //{
        //    _udpPort = udpPort;
        //    SetNetworkName(networkName);
        //    SetTCPPort(_tcpPort);
        //}
        /// <summary>
        /// Creates a UDP server that keeps the supplied parameters object.
        /// </summary>
        /// <param name="hParams">Terminal Params class; stored as-is, without a null check.</param>

        public UDPServer(HParams hParams)
        {
            // Only the reference is stored; nothing is read from hParams here.
            _hParams = hParams;
        }