/// <summary>
/// Builds a text classification model: the shared encoder from the base class,
/// plus a classification head producing <paramref name="numClasses"/> outputs.
/// </summary>
/// <param name="options">Trainer options; supplies the encoder output size and pooler dropout.</param>
/// <param name="vocabulary">Token vocabulary forwarded to the base model.</param>
/// <param name="numClasses">Number of target classes for the prediction head.</param>
public TextClassificationModel(TextClassificationTrainer.Options options, Vocabulary vocabulary, int numClasses)
    : base(options, vocabulary)
{
    _predictionHead = new PredictionHead(
        inputDim: Options.EncoderOutputDim,
        numClasses: numClasses,
        dropoutRate: Options.PoolerDropout);

    // Call order preserved from the original: initialize first, then register
    // components (presumably so the module tree tracks the new head's parameters
    // — confirm against the base class contract).
    Initialize();
    RegisterComponents();
}
/// <summary>
/// Create and return an optimizer according to <paramref name="options"/>.
/// </summary>
/// <param name="options">Trainer options supplying the optimizer hyper-parameters.</param>
/// <param name="parameters">The parameters to be optimized by the optimizer.</param>
/// <returns>A new <see cref="Adam"/> optimizer; Adam is currently the only supported choice.</returns>
public static BaseOptimizer GetOptimizer(TextClassificationTrainer.Options options, IEnumerable<Parameter> parameters)
{
    // Only Adam is supported today. If additional optimizers are added, switch
    // on the configured optimizer name here and throw NotSupportedException for
    // unrecognized values.
    return new Adam(options, parameters);
}
#pragma warning restore CA1024 // Use properties where appropriate

/// <summary>
/// Base model wiring: validates the arguments, stores the options, and builds
/// the transformer encoder from the option values and the vocabulary.
/// </summary>
/// <param name="options">Trainer options controlling the encoder hyper-parameters.</param>
/// <param name="vocabulary">Vocabulary providing the padding index and vocabulary size.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="options"/> or <paramref name="vocabulary"/> is null.
/// </exception>
protected BaseModel(TextClassificationTrainer.Options options, Vocabulary vocabulary)
    : base(nameof(BaseModel))
{
    if (vocabulary is null)
    {
        throw new ArgumentNullException(nameof(vocabulary));
    }

    if (options is null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    Options = options;

    // Argument order kept exactly as in the original so evaluation order of the
    // option reads is unchanged.
    Encoder = new TransformerEncoder(
        paddingIdx: vocabulary.PadIndex,
        vocabSize: vocabulary.Count,
        dropout: Options.Dropout,
        attentionDropout: Options.AttentionDropout,
        activationDropout: Options.ActivationDropout,
        activationFn: Options.ActivationFunction,
        dynamicDropout: Options.DynamicDropout,
        maxSeqLen: Options.MaxSequenceLength,
        embedSize: Options.EmbeddingDim,
        arches: Options.Arches?.ToList(),
        numSegments: 0,
        encoderNormalizeBefore: Options.EncoderNormalizeBefore,
        numEncoderLayers: Options.EncoderLayers,
        applyBertInit: true,
        freezeTransfer: Options.FreezeTransfer);
}
/// <summary>
/// Base optimizer plumbing: stores the optimizer name and options, and takes a
/// one-time snapshot of the parameters to optimize.
/// </summary>
/// <param name="name">Display name of the concrete optimizer (e.g. nameof(Adam)).</param>
/// <param name="options">Trainer options supplying the optimizer hyper-parameters.</param>
/// <param name="parameters">The parameters to be optimized; enumerated once and materialized.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="options"/> or <paramref name="parameters"/> is null.
/// </exception>
protected BaseOptimizer(string name, TextClassificationTrainer.Options options, IEnumerable<Parameter> parameters)
{
    // Validate up front (consistent with the BaseModel constructor) instead of
    // letting a null "parameters" surface as a NullReferenceException below.
    if (parameters is null)
    {
        throw new ArgumentNullException(nameof(parameters));
    }

    Name = name;
    Options = options ?? throw new ArgumentNullException(nameof(options));
    // Materialize once: the incoming sequence may be lazy and should not be
    // re-enumerated by derived classes.
    Parameters = parameters.ToArray();
}
/// <summary>
/// Adam optimizer wrapper: constructs a TorchSharp Adam optimizer over the
/// given parameters using the hyper-parameters from <paramref name="options"/>.
/// </summary>
/// <param name="options">Trainer options supplying learning rate, betas, epsilon, and weight decay.</param>
/// <param name="parameters">The parameters to be optimized.</param>
public Adam(TextClassificationTrainer.Options options, IEnumerable<Parameter> parameters)
    : base(nameof(Adam), options, parameters)
{
    // Name the hyper-parameters for readability. The first element of
    // LearningRate is used — presumably the initial rate of a schedule; confirm
    // against the options definition.
    var learningRate = options.LearningRate[0];
    var beta1 = options.AdamBetas[0];
    var beta2 = options.AdamBetas[1];

    Optimizer = torch.optim.Adam(
        Parameters,
        learningRate,
        beta1,
        beta2,
        options.AdamEps,
        options.WeightDecay);
}