/// <summary>
/// Builds a text classification model: the base model is constructed from
/// <paramref name="options"/> and <paramref name="vocabulary"/>, then a prediction head
/// sized for <paramref name="numClasses"/> is attached.
/// </summary>
/// <param name="options">Trainer options; forwarded to the base constructor. The head's input
/// dimension and dropout rate are read from the <c>Options</c> member.</param>
/// <param name="vocabulary">Vocabulary forwarded to the base constructor.</param>
/// <param name="numClasses">Number of classes passed to the prediction head.</param>
public TextClassificationModel(TextClassificationTrainer.Options options, Vocabulary vocabulary, int numClasses)
    : base(options, vocabulary)
{
    // Read the head's configuration up front so the construction below is easy to scan.
    var headInputDim = Options.EncoderOutputDim;
    var headDropout = Options.PoolerDropout;

    _predictionHead = new PredictionHead(
        inputDim: headInputDim,
        numClasses: numClasses,
        dropoutRate: headDropout);

    // Called only after the prediction head has been assigned; both methods are defined elsewhere.
    Initialize();
    RegisterComponents();
}
Ejemplo n.º 2
0
 /// <summary>
 /// Creates the optimizer to use for the supplied <paramref name="options"/>.
 /// </summary>
 /// <remarks>
 /// An <see cref="Adam"/> instance is always returned; selecting an optimizer by name
 /// is not implemented here.
 /// </remarks>
 /// <param name="options">Trainer options handed to the optimizer's constructor.</param>
 /// <param name="parameters">The parameters to be optimized by the optimizer.</param>
 /// <returns>A new <see cref="Adam"/> built from <paramref name="options"/> and <paramref name="parameters"/>.</returns>
 public static BaseOptimizer GetOptimizer(TextClassificationTrainer.Options options, IEnumerable <Parameter> parameters)
     => new Adam(options, parameters);
Ejemplo n.º 3
0
#pragma warning restore CA1024 // Use properties where appropriate

        /// <summary>
        /// Initializes the shared model state: stores the trainer options and builds the
        /// transformer encoder from them and from the vocabulary.
        /// </summary>
        /// <param name="options">Trainer options; stored in <c>Options</c> and used to configure the encoder.</param>
        /// <param name="vocabulary">Vocabulary supplying the padding index and vocabulary size for the encoder.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="vocabulary"/> or <paramref name="options"/> is <see langword="null"/>.
        /// </exception>
        protected BaseModel(TextClassificationTrainer.Options options, Vocabulary vocabulary)
            : base(nameof(BaseModel))
        {
            // Validate in this order: vocabulary first, then options.
            if (vocabulary is null)
            {
                throw new ArgumentNullException(nameof(vocabulary));
            }

            if (options is null)
            {
                throw new ArgumentNullException(nameof(options));
            }

            Options = options;

            // numSegments is fixed at 0 and applyBertInit is fixed at true;
            // every other setting comes from Options or the vocabulary.
            Encoder = new TransformerEncoder(
                paddingIdx: vocabulary.PadIndex,
                vocabSize: vocabulary.Count,
                dropout: Options.Dropout,
                attentionDropout: Options.AttentionDropout,
                activationDropout: Options.ActivationDropout,
                activationFn: Options.ActivationFunction,
                dynamicDropout: Options.DynamicDropout,
                maxSeqLen: Options.MaxSequenceLength,
                embedSize: Options.EmbeddingDim,
                arches: Options.Arches?.ToList(),
                numSegments: 0,
                encoderNormalizeBefore: Options.EncoderNormalizeBefore,
                numEncoderLayers: Options.EncoderLayers,
                applyBertInit: true,
                freezeTransfer: Options.FreezeTransfer);
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Initializes the shared optimizer state: its name, the trainer options, and a
 /// snapshot of the parameters to optimize.
 /// </summary>
 /// <param name="name">Name stored in <c>Name</c>.</param>
 /// <param name="options">Trainer options stored in <c>Options</c>.</param>
 /// <param name="parameters">Parameters to optimize; the sequence is enumerated once and copied into an array.</param>
 /// <exception cref="ArgumentNullException">
 /// <paramref name="options"/> or <paramref name="parameters"/> is <see langword="null"/>.
 /// </exception>
 protected BaseOptimizer(string name, TextClassificationTrainer.Options options, IEnumerable <Parameter> parameters)
 {
     Name       = name;
     // Fail fast with the offending argument's name instead of storing a null Options
     // (which would only surface later as a NullReferenceException in a derived class).
     Options    = options ?? throw new ArgumentNullException(nameof(options));
     // Without this guard a null sequence is reported by LINQ as parameter "source".
     Parameters = (parameters ?? throw new ArgumentNullException(nameof(parameters))).ToArray();
 }
Ejemplo n.º 5
0
 /// <summary>
 /// Creates an Adam optimizer over <paramref name="parameters"/>, configured from
 /// <paramref name="options"/>.
 /// </summary>
 /// <param name="options">Supplies the learning rate, the two Adam betas, epsilon and weight decay.</param>
 /// <param name="parameters">Parameters to optimize; forwarded to the base constructor.</param>
 public Adam(TextClassificationTrainer.Options options, IEnumerable <Parameter> parameters)
     : base(nameof(Adam), options, parameters)
 {
     // Use the parameter collection the base constructor stored in Parameters.
     var targets = Parameters;

     // Only the first entry of LearningRate is used.
     var learningRate = options.LearningRate[0];
     var beta1 = options.AdamBetas[0];
     var beta2 = options.AdamBetas[1];
     var epsilon = options.AdamEps;
     var weightDecay = options.WeightDecay;

     Optimizer = torch.optim.Adam(targets, learningRate, beta1, beta2, epsilon, weightDecay);
 }