public void Should_Return_Non_Existent_Domain_Because_No_Entry_Matches()
 {
     // Arrange: the resolver always hands back a default DnsResolutionResult (no DnsEntry assigned).
     var emptyResult = new DnsResolutionResult();
     var resolverStub = Mock.Of<IDnsResolver>(r => r.Resolve(It.IsAny<DnsRequest>()) == emptyResult);
     IDnsQueryHandler handler = new DnsQueryHandler(resolverStub);
     var question = new DnsQuestion("test", RecordType.A, RecordClass.Any);
     var request = new DnsMessage() { Questions = new List<DnsQuestion>() { question } };

     // Act
     DnsMessageBase response = handler.HandleQuery(request, null, ProtocolType.IP);

     // Assert: a question the resolver cannot match yields NXDOMAIN.
     response.ReturnCode.Should().Be(ReturnCode.NxDomain);
 }
 public void Should_Return_DNS_Error_Because_No_Entry_Was_Requested()
 {
     // Arrange: a freshly constructed DnsMessage with no questions added.
     IDnsQueryHandler handler = new DnsQueryHandler(Mock.Of<IDnsResolver>());
     var emptyQuery = new DnsMessage();

     // Act
     DnsMessageBase response = handler.HandleQuery(emptyQuery, null, ProtocolType.IP);

     // Assert: a query without questions is reported as a server failure.
     response.ReturnCode.Should().Be(ReturnCode.ServerFailure);
 }
Beispiel #3
0
        /// <summary>
        ///   Handles an incoming query in place: the first question is passed to ProcessQuestion,
        ///   and a query without any questions is marked as a server failure.
        /// </summary>
        /// <param name="query">The query message; it may be modified in place.</param>
        private void ProcessQuery(DnsMessage query)
        {
            bool hasQuestion = query.Questions.Any();
            if (hasQuestion)
            {
                // Only the first question is processed; any further questions are ignored.
                ProcessQuestion(query, query.Questions[0]);
            }
            else
            {
                query.ReturnCode = ReturnCode.ServerFailure;
            }
        }
 public void Should_Return_No_Error_Because_Entry_Matches()
 {
     // Arrange: the resolver always answers with an IPv4 entry named "test".
     var matchingEntry = new DnsEntry() { Name = "test", IpV4 = "0.0.0.0" };
     var resolutionResult = new DnsResolutionResult() { DnsEntry = matchingEntry };
     var resolverStub = Mock.Of<IDnsResolver>(r => r.Resolve(It.IsAny<DnsRequest>()) == resolutionResult);
     IDnsQueryHandler handler = new DnsQueryHandler(resolverStub);
     var request = new DnsMessage()
     {
         Questions = new List<DnsQuestion>() { new DnsQuestion("test", RecordType.A, RecordClass.Any) }
     };

     // Act: the outcome is checked on the request message itself.
     handler.HandleQuery(request, null, ProtocolType.IP);

     // Assert
     request.ReturnCode.Should().Be(ReturnCode.NoError);
 }
Beispiel #5
0
        /// <summary>
        ///   Resolves a single question through <c>dnsResolver</c> and writes the outcome into <paramref name="message"/>.
        /// </summary>
        /// <param name="message">Message that receives the return code / answer record.</param>
        /// <param name="dnsQuestion">Question whose name and record type are sent to the resolver.</param>
        /// <remarks>
        ///   Sets NxDomain when the resolver returns no result or no entry, and ServerFailure when the
        ///   entry's IPv4 text cannot be parsed. Otherwise an A record with TTL 3600 is appended.
        ///   NOTE(review): the answer is always an A record, regardless of <c>dnsQuestion.RecordType</c>.
        /// </remarks>
        private void ProcessQuestion(DnsMessage message, DnsQuestion dnsQuestion)
        {
            var dnsRequest = new DnsRequest() { Name = dnsQuestion.Name, Type = dnsQuestion.RecordType.ToString() };
            var result = dnsResolver.Resolve(dnsRequest);

            // A resolver may return null instead of an empty result; treat both as "no such name"
            // rather than throwing a NullReferenceException.
            if (result == null || result.DnsEntry == null)
            {
                message.ReturnCode = ReturnCode.NxDomain;
                return;
            }

            IPAddress address;
            if (!IPAddress.TryParse(result.DnsEntry.IpV4, out address))
            {
                // The entry exists but has no usable IPv4 address. Report a server-side failure instead
                // of letting IPAddress.Parse throw out of the query handler.
                message.ReturnCode = ReturnCode.ServerFailure;
                return;
            }

            message.AnswerRecords.Add(new ARecord(dnsRequest.Name, 3600, address));
        }
        /// <summary>
        ///   Creates the message type matching the opcode in the header flags, then parses <paramref name="data"/> into it.
        /// </summary>
        /// <param name="data">Raw wire-format message.</param>
        /// <param name="tsigKeySelector">Selector forwarded to ParseInternal for TSIG handling.</param>
        /// <param name="originalMac">MAC forwarded to ParseInternal.</param>
        /// <returns>A <see cref="DnsUpdateMessage"/> for the Update opcode, otherwise a <see cref="DnsMessage"/>.</returns>
        internal static DnsMessageBase CreateByFlag(byte[] data, DnsServer.SelectTsigKey tsigKeySelector, byte[] originalMac)
        {
            // The 16-bit header flags are read at byte offset 2 (after the transaction ID, RFC 1035 4.1.1).
            int offset = 2;
            ushort headerFlags = ParseUShort(data, ref offset);

            // Mask 0x7800 isolates the 4 opcode bits; shift them down to compare with OperationCode.
            var opCode = (OperationCode)((headerFlags & 0x7800) >> 11);

            DnsMessageBase message = (opCode == OperationCode.Update)
                ? (DnsMessageBase)new DnsUpdateMessage()
                : new DnsMessage();

            message.ParseInternal(data, tsigKeySelector, originalMac);
            return message;
        }
        /// <summary>
        ///   Builds a response message that mirrors this message's header settings and questions.
        /// </summary>
        /// <returns>
        ///   A new <see cref="DnsMessage"/> with IsQuery cleared. The question list is copied, and the EDNS
        ///   version and UDP payload size are copied when EDNS is enabled.
        /// </returns>
        public DnsMessage CreateResponseInstance()
        {
            var response = new DnsMessage();
            response.TransactionID      = TransactionID;
            response.IsEDnsEnabled      = IsEDnsEnabled;
            response.IsQuery            = false;
            response.OperationCode      = OperationCode;
            response.IsRecursionDesired = IsRecursionDesired;
            response.IsCheckingDisabled = IsCheckingDisabled;
            response.IsDnsSecOk         = IsDnsSecOk;
            response.Questions          = new List <DnsQuestion>(Questions);

            if (!IsEDnsEnabled)
            {
                return response;
            }

            response.EDnsOptions.Version        = EDnsOptions.Version;
            response.EDnsOptions.UdpPayloadSize = EDnsOptions.UdpPayloadSize;
            return response;
        }
Beispiel #8
0
        /// <summary>
        ///   Creates the message type matching the opcode in the header flags, then parses <paramref name="data"/> into it.
        /// </summary>
        /// <param name="data">Raw wire-format message.</param>
        /// <returns>A <see cref="DnsUpdateMessage"/> for the Update opcode, otherwise a <see cref="DnsMessage"/>.</returns>
        internal static DnsMessageBase CreateByFlag(byte[] data)
        {
            // Header flags are the 16-bit word at byte offset 2; the opcode occupies bits 11-14.
            int cursor = 2;
            ushort flagWord = ParseUShort(data, ref cursor);
            bool isUpdate = (OperationCode)((flagWord & 0x7800) >> 11) == OperationCode.Update;

            DnsMessageBase parsed;
            if (isUpdate)
            {
                parsed = new DnsUpdateMessage();
            }
            else
            {
                parsed = new DnsMessage();
            }

            parsed.ParseInternal(data);
            return parsed;
        }
Beispiel #9
0
        /// <summary>
        ///   Creates the message type matching the opcode in the header flags, then parses <paramref name="resultData"/> into it.
        /// </summary>
        /// <param name="resultData">Raw wire-format message.</param>
        /// <param name="isRequest">Forwarded to Parse.</param>
        /// <param name="tsigKeySelector">Selector forwarded to Parse for TSIG handling.</param>
        /// <param name="originalMac">MAC forwarded to Parse.</param>
        /// <returns>A <see cref="DnsUpdateMessage"/> for the Update opcode, otherwise a <see cref="DnsMessage"/>.</returns>
        public static DnsMessageBase Create(byte[] resultData, bool isRequest, DnsServer.SelectTsigKey tsigKeySelector, byte[] originalMac)
        {
            int flagPosition = 2;

            ushort flags = ParseUShort(resultData, ref flagPosition);

            DnsMessageBase res;

            // The opcode occupies bits 11-14 of the flags word. It must be shifted down before it is
            // compared with OperationCode values; the unshifted masked value never equals Update,
            // so update messages were previously always parsed as plain DnsMessage.
            switch ((OperationCode)((flags & 0x7800) >> 11))
            {
            case OperationCode.Update:
                res = new DnsUpdateMessage();
                break;

            default:
                res = new DnsMessage();
                break;
            }

            res.Parse(resultData, isRequest, tsigKeySelector, originalMac);

            return(res);
        }
		/// <summary>
		///   Queries a dns server for specified records.
		/// </summary>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClass"> Class that should be queried </param>
		/// <param name="options"> Options for the query; when null, recursion and EDNS are enabled </param>
		/// <returns> The complete response of the dns server </returns>
		/// <exception cref="ArgumentNullException"> <paramref name="name"/> is null </exception>
		public DnsMessage Resolve(DomainName name, RecordType recordType = RecordType.A, RecordClass recordClass = RecordClass.INet, DnsQueryOptions options = null)
		{
			if (name == null)
			{
				throw new ArgumentNullException(nameof(name), "Name must be provided");
			}

			var query = new DnsMessage()
			{
				IsQuery = true,
				OperationCode = OperationCode.Query,
				IsRecursionDesired = true,
				IsEDnsEnabled = true
			};

			if (options != null)
			{
				// Caller-supplied options override the defaults set in the initializer.
				query.IsRecursionDesired = options.IsRecursionDesired;
				query.IsCheckingDisabled = options.IsCheckingDisabled;
				query.EDnsOptions = options.EDnsOptions;
			}
			else
			{
				// Defaults re-applied (same values as the initializer above).
				query.IsRecursionDesired = true;
				query.IsEDnsEnabled = true;
			}

			query.Questions.Add(new DnsQuestion(name, recordType, recordClass));

			return SendMessage(query);
		}
Beispiel #11
0
        /// <summary>
        ///   Forwards a client query to <paramref name="targetNameServer"/> over UDP, then waits up to
        ///   <paramref name="queryTimeout"/> and drops the pending transaction if it is still outstanding.
        /// </summary>
        /// <param name="message">Parsed client query; re-encoded when <paramref name="useCompressionMutation"/> is set.</param>
        /// <param name="originalUdpMessage">Datagram as received from the client (raw buffer and sender endpoint).</param>
        /// <param name="targetNameServer">Upstream name server to forward to.</param>
        /// <param name="queryTimeout">Delay in milliseconds before the transaction is treated as timed out.</param>
        /// <param name="useCompressionMutation">Re-encode <paramref name="message"/> instead of forwarding the raw bytes.</param>
        /// <remarks>
        ///   This method only sends; it does not read the upstream answer. On any caught failure, a response
        ///   produced by Utils.ReturnDnsMessageServerFailure is sent back to the client.
        /// </remarks>
        private async Task ForwardMessage(DnsMessage message, UdpReceiveResult originalUdpMessage,
            IPEndPoint targetNameServer, int queryTimeout,
            bool useCompressionMutation)
        {
            DnsQuestion question = null;
            if (message.Questions.Count > 0)
                question = message.Questions[0];

            byte[] responseBuffer = null;
            try
            {
                // Forwarding to our own listener (loopback address + our listening port) would loop forever.
                if ((Equals(targetNameServer.Address, IPAddress.Loopback) ||
                     Equals(targetNameServer.Address, IPAddress.IPv6Loopback)) &&
                    targetNameServer.Port == ((IPEndPoint) _udpListener.Client.LocalEndPoint).Port)
                    throw new InfiniteForwardingException(question);

                byte[] sendBuffer;
                if (useCompressionMutation)
                    message.Encode(false, out sendBuffer, true);
                else
                    sendBuffer = originalUdpMessage.Buffer;

                // Remember the requesting client for this transaction ID (presumably read by the
                // upstream-response handler to route the answer back; confirm).
                _transactionClients[message.TransactionID] = originalUdpMessage.RemoteEndPoint;

                // Send to Forwarder
                await _udpForwarder.SendAsync(sendBuffer, sendBuffer.Length, targetNameServer);

                // Supersede any timeout still pending for the same transaction ID. TryGetValue replaces the
                // ContainsKey + indexer double lookup, which throws KeyNotFoundException if the entry is
                // removed concurrently between the two calls.
                CancellationTokenSource previousTokenSource;
                if (_transactionTimeoutCancellationTokenSources.TryGetValue(message.TransactionID, out previousTokenSource))
                    previousTokenSource.Cancel();
                var cancellationTokenSource = new CancellationTokenSource();
                _transactionTimeoutCancellationTokenSources[message.TransactionID] = cancellationTokenSource;

                // Timeout task to cancel the request
                try
                {
                    await Task.Delay(queryTimeout, cancellationTokenSource.Token);
                    // The client entry is gone, so the transaction is no longer pending.
                    if (!_transactionClients.ContainsKey(message.TransactionID)) return;
                    IPEndPoint ignoreEndPoint;
                    CancellationTokenSource ignoreTokenSource;
                    _transactionClients.TryRemove(message.TransactionID, out ignoreEndPoint);
                    _transactionTimeoutCancellationTokenSources.TryRemove(message.TransactionID,
                        out ignoreTokenSource);

                    var warningText = message.Questions.Count > 0
                        ? $"{message.Questions[0].Name} (Type {message.Questions[0].RecordType})"
                        : $"Transaction #{message.TransactionID}";
                    Logger.Warning("Query timeout for: {0}", warningText);
                }
                catch (TaskCanceledException)
                {
                    // The timeout was cancelled (superseded or answered): nothing to report.
                }
            }
            catch (InfiniteForwardingException e)
            {
                // The question is null when the client query had no question section. Dereferencing it
                // here would throw inside the handler and skip the ServerFailure reply below.
                var loopQuestion = e.Question;
                if (loopQuestion != null)
                    Logger.Warning("[Forwarder.Send] Infinite forwarding detected for: {0} (Type {1})", loopQuestion.Name,
                        loopQuestion.RecordType);
                else
                    Logger.Warning("[Forwarder.Send] Infinite forwarding detected for: {0}",
                        $"Transaction #{message.TransactionID}");
                Utils.ReturnDnsMessageServerFailure(message, out responseBuffer);
            }
            catch (SocketException e)
            {
                if (e.SocketErrorCode == SocketError.ConnectionReset) // Target name server port unreachable
                    Logger.Warning("[Forwarder.Send] Name server port unreachable: {0}", targetNameServer);
                else
                    Logger.Error("[Forwarder.Send] Unhandled socket error: {0}", e.Message);
                Utils.ReturnDnsMessageServerFailure(message, out responseBuffer);
            }
            catch (Exception e)
            {
                Logger.Error("[Forwarder] Unexpected exception:\n{0}", e);
                Utils.ReturnDnsMessageServerFailure(message, out responseBuffer);
            }

            // If we got some errors
            if (responseBuffer != null)
                await _udpListener.SendAsync(responseBuffer, responseBuffer.Length, originalUdpMessage.RemoteEndPoint);
        }
Beispiel #12
0
        // Resolves records of recordType/recordClass for name, using the cache first and following CNAMEs,
        // and returns them together with their DNSSEC validation result.
        // Throws DnsSecValidationException when a validation result is Bogus or Indeterminate, and a plain
        // Exception when the response contains neither matching records, a CNAME, nor a covering SOA.
        // NOTE(review): there is no CNAME-loop / depth limit; a cyclic CNAME chain recurses without bound.
        private async Task <DnsSecResult <T> > ResolveAsyncInternal <T>(DomainName name, RecordType recordType, RecordClass recordClass, State state, CancellationToken token)
            where T : DnsRecordBase
        {
            DnsCacheRecordList <T> cachedResults;

            // Exact cache hit: return the cached records with the validation result stored alongside them.
            if (_cache.TryGetRecords(name, recordType, recordClass, out cachedResults))
            {
                return(new DnsSecResult <T>(cachedResults, cachedResults.ValidationResult));
            }

            DnsCacheRecordList <CNameRecord> cachedCNames;

            // Cached CNAME: follow the (first) canonical name. The combined result keeps the validation
            // result only when both hops agree, otherwise it downgrades to Unsigned.
            if (_cache.TryGetRecords(name, RecordType.CName, recordClass, out cachedCNames))
            {
                var cNameResult = await ResolveAsyncInternal <T>(cachedCNames.First().CanonicalName, recordType, recordClass, state, token);

                return(new DnsSecResult <T>(cNameResult.Records, cachedCNames.ValidationResult == cNameResult.ValidationResult ? cachedCNames.ValidationResult : DnsSecValidationResult.Unsigned));
            }

            DnsMessage msg = await ResolveMessageAsync(name, recordType, recordClass, state, token);

            // check for cname
            List <DnsRecordBase> cNameRecords = msg.AnswerRecords.Where(x => (x.RecordType == RecordType.CName) && (x.RecordClass == recordClass) && x.Name.Equals(name)).ToList();

            if (cNameRecords.Count > 0)
            {
                DnsSecValidationResult cNameValidationResult = await _validator.ValidateAsync(name, RecordType.CName, recordClass, msg, cNameRecords, state, token);

                if ((cNameValidationResult == DnsSecValidationResult.Bogus) || (cNameValidationResult == DnsSecValidationResult.Indeterminate))
                {
                    throw new DnsSecValidationException("CNAME record could not be validated");
                }

                // Cache the CNAME set under the queried name, using the smallest TTL in the set.
                _cache.Add(name, RecordType.CName, recordClass, cNameRecords, cNameValidationResult, cNameRecords.Min(x => x.TimeToLive));

                DomainName canonicalName = ((CNameRecord)cNameRecords.First()).CanonicalName;

                // The same response may already carry the target records for the canonical name.
                List <DnsRecordBase> matchingAdditionalRecords = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClass == recordClass) && x.Name.Equals(canonicalName)).ToList();
                if (matchingAdditionalRecords.Count > 0)
                {
                    DnsSecValidationResult matchingValidationResult = await _validator.ValidateAsync(canonicalName, recordType, recordClass, msg, matchingAdditionalRecords, state, token);

                    if ((matchingValidationResult == DnsSecValidationResult.Bogus) || (matchingValidationResult == DnsSecValidationResult.Indeterminate))
                    {
                        throw new DnsSecValidationException("CNAME matching records could not be validated");
                    }

                    // Combined result: agreeing results are kept, otherwise Unsigned.
                    DnsSecValidationResult validationResult = cNameValidationResult == matchingValidationResult ? cNameValidationResult : DnsSecValidationResult.Unsigned;
                    _cache.Add(canonicalName, recordType, recordClass, matchingAdditionalRecords, validationResult, matchingAdditionalRecords.Min(x => x.TimeToLive));

                    return(new DnsSecResult <T>(matchingAdditionalRecords.OfType <T>().ToList(), validationResult));
                }

                // Target records not included: resolve the canonical name separately.
                var cNameResults = await ResolveAsyncInternal <T>(canonicalName, recordType, recordClass, state, token);

                return(new DnsSecResult <T>(cNameResults.Records, cNameValidationResult == cNameResults.ValidationResult ? cNameValidationResult : DnsSecValidationResult.Unsigned));
            }

            // check for "normal" answer
            List <DnsRecordBase> answerRecords = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClass == recordClass) && x.Name.Equals(name)).ToList();

            if (answerRecords.Count > 0)
            {
                DnsSecValidationResult validationResult = await _validator.ValidateAsync(name, recordType, recordClass, msg, answerRecords, state, token);

                if ((validationResult == DnsSecValidationResult.Bogus) || (validationResult == DnsSecValidationResult.Indeterminate))
                {
                    throw new DnsSecValidationException("Response records could not be validated");
                }

                _cache.Add(name, recordType, recordClass, answerRecords, validationResult, answerRecords.Min(x => x.TimeToLive));
                return(new DnsSecResult <T>(answerRecords.OfType <T>().ToList(), validationResult));
            }

            // check for negative answer: an SOA in the authority section whose owner is the queried
            // name or one of its parent domains
            SoaRecord soaRecord = msg.AuthorityRecords
                                  .Where(x =>
                                         (x.RecordType == RecordType.Soa) &&
                                         (name.Equals(x.Name) || name.IsSubDomainOf(x.Name)))
                                  .OfType <SoaRecord>()
                                  .FirstOrDefault();

            if (soaRecord != null)
            {
                // answerRecords is empty at this point; the validator is asked to validate the negative response.
                DnsSecValidationResult validationResult = await _validator.ValidateAsync(name, recordType, recordClass, msg, answerRecords, state, token);

                if ((validationResult == DnsSecValidationResult.Bogus) || (validationResult == DnsSecValidationResult.Indeterminate))
                {
                    throw new DnsSecValidationException("Negative answer could not be validated");
                }

                // Cache the empty result for the SOA's negative-caching TTL.
                _cache.Add(name, recordType, recordClass, new List <DnsRecordBase>(), validationResult, soaRecord.NegativeCachingTTL);
                return(new DnsSecResult <T>(new List <T>(), validationResult));
            }

            // authoritive response does not contain answer
            throw new Exception("Could not resolve " + name);
        }
 /// <summary>
 ///   Mock resolver that answers A (or ANY) queries from the in-memory <c>dnsMockRecords</c> table.
 /// </summary>
 /// <param name="name">Name to look up in <c>dnsMockRecords</c>.</param>
 /// <param name="recordType">Only <c>A</c> and <c>Any</c> produce an answer.</param>
 /// <param name="recordClass">Unused.</param>
 /// <returns>A message with a single A record (TTL 0), or null when the type or name is not mocked.</returns>
 private DnsMessage mockResolveDns(string name, RecordType recordType, RecordClass recordClass)
 {
     bool wantsAddress = recordType == RecordType.Any || recordType == RecordType.A;
     if (!wantsAddress || !dnsMockRecords.ContainsKey(name))
     {
         return null;
     }

     var response = new DnsMessage();
     response.AnswerRecords.Add(new ARecord(name, 0, dnsMockRecords[name]));
     return response;
 }
Beispiel #14
0
        /// <summary>
        ///   Handles one UDP datagram. It raises <c>ClientConnected</c>, parses and processes the query, and
        ///   encodes the response, trimming records until it fits the UDP size limit.
        /// </summary>
        /// <param name="remoteEp">Endpoint the datagram came from.</param>
        /// <param name="buffer">Raw query bytes; the parameter is reused to hold the encoded response.</param>
        /// <returns>
        ///   The encoded response, or null when the connection was refused or an exception occurred
        ///   (exceptions are passed to OnExceptionThrownAsync).
        /// </returns>
        public async Task <ArraySegment <byte>?> HandleUdpMessage(IPEndPoint remoteEp, byte[] buffer)
        {
            try {
                ClientConnectedEventArgs clientConnectedEventArgs = new ClientConnectedEventArgs(ProtocolType.Udp, remoteEp);
                await ClientConnected.RaiseAsync(this, clientConnectedEventArgs);

                if (clientConnectedEventArgs.RefuseConnect)
                {
                    return(null);
                }

                DnsMessageBase query;
                byte[]         originalMac;
                try {
                    query       = DnsMessageBase.CreateByFlag(buffer, TsigKeySelector, null);
                    originalMac = query.TSigOptions?.Mac;
                } catch (Exception e) {
                    throw new Exception("Error parsing dns query", e);
                }

                DnsMessageBase response;
                try {
                    response = await ProcessMessageAsync(query, ProtocolType.Udp, remoteEp);
                } catch (Exception ex) {
                    OnExceptionThrownAsync(ex);
                    response = null;
                }

                if (response == null)
                {
                    // No response produced: send the query back as a ServerFailure response.
                    response         = query;
                    query.IsQuery    = false;
                    query.ReturnCode = ReturnCode.ServerFailure;
                }

                int length = response.Encode(false, originalMac, out buffer);

                #region Truncating
                DnsMessage message = response as DnsMessage;

                if (message != null)
                {
                    // 512 bytes is the classic DNS-over-UDP limit. When both query and response use EDNS,
                    // the response's advertised UDP payload size is used if larger.
                    int maxLength = 512;
                    if (query.IsEDnsEnabled && message.IsEDnsEnabled)
                    {
                        maxLength = Math.Max(512, (int)message.EDnsOptions.UdpPayloadSize);
                    }

                    while (length > maxLength)
                    {
                        // First step: remove data from additional records except the opt record
                        if ((message.IsEDnsEnabled && (message.AdditionalRecords.Count > 1)) || (!message.IsEDnsEnabled && (message.AdditionalRecords.Count > 0)))
                        {
                            bool removedAdditional = false;
                            for (int i = message.AdditionalRecords.Count - 1; i >= 0; i--)
                            {
                                if (message.AdditionalRecords[i].RecordType != RecordType.Opt)
                                {
                                    message.AdditionalRecords.RemoveAt(i);
                                    removedAdditional = true;
                                }
                            }

                            // Only retry when something was actually dropped. If only OPT records remain,
                            // fall through to the next section instead of re-looping on an unchanged message.
                            if (removedAdditional)
                            {
                                length = message.Encode(false, originalMac, out buffer);
                                continue;
                            }
                        }

                        // Later steps drop records from the end of each section until the estimated
                        // size fits, then mark the message as truncated.
                        int savedLength = 0;
                        if (message.AuthorityRecords.Count > 0)
                        {
                            for (int i = message.AuthorityRecords.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.AuthorityRecords[i].MaximumLength;
                                message.AuthorityRecords.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        if (message.AnswerRecords.Count > 0)
                        {
                            for (int i = message.AnswerRecords.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.AnswerRecords[i].MaximumLength;
                                message.AnswerRecords.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        if (message.Questions.Count > 0)
                        {
                            for (int i = message.Questions.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.Questions[i].MaximumLength;
                                message.Questions.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Nothing is left to drop (e.g. the OPT/TSIG data alone exceeds maxLength). Stop
                        // rather than spin forever, and send the smallest message that could be built.
                        break;
                    }
                }
                #endregion

                return(new ArraySegment <byte>(buffer, 0, length));
            } catch (Exception ex) {
                OnExceptionThrownAsync(ex);
                return(null);
            }
        }
Beispiel #15
0
 /// <summary>
 ///   Resolves <paramref name="domainName"/> through an HTTP DNS endpoint and fills <paramref name="message"/>
 ///   with the resulting A records and return code.
 /// </summary>
 /// <param name="targetNameServer">Host (optionally host:port) of the HTTP DNS service.</param>
 /// <param name="domainName">Name to resolve; URL-escaped before it is placed in the query string.</param>
 /// <param name="timeout">Request timeout in milliseconds (also enforced for the asynchronous call).</param>
 /// <param name="message">Message that receives the answer records, return code and IsQuery = false.</param>
 /// <remarks>
 ///   The body is parsed as "ip1;ip2;...,ttl", and an empty body means NXDOMAIN. This format is
 ///   inferred from the parsing below; confirm it against the service.
 /// </remarks>
 /// <exception cref="WebException">The request failed or did not complete within <paramref name="timeout"/>.</exception>
 /// <exception cref="FormatException">The body does not contain an IP list and a TTL.</exception>
 private static async Task ResolveWithHttp(string targetNameServer, string domainName, int timeout, DnsMessage message)
 {
     // Escape the name so characters like '&' or '=' cannot inject extra query parameters.
     var request = WebRequest.Create($"http://{targetNameServer}/d?dn={Uri.EscapeDataString(domainName)}&ttl=1");
     request.Timeout = timeout;

     // WebRequest.Timeout has no effect on asynchronous requests on .NET Framework, so enforce it here.
     var responseTask = request.GetResponseAsync();
     if (await Task.WhenAny(responseTask, Task.Delay(timeout)) != responseTask)
     {
         request.Abort();
         throw new WebException("HTTP DNS request timed out.", WebExceptionStatus.Timeout);
     }

     // Dispose the response so the underlying connection is released.
     using (var response = await responseTask)
     {
         var stream = response.GetResponseStream();
         if (stream == null)
             throw new Exception("Invalid HTTP response stream.");
         using (var reader = new StreamReader(stream))
         {
             var result = await reader.ReadToEndAsync();
             if (string.IsNullOrWhiteSpace(result))
             {
                 message.ReturnCode = ReturnCode.NxDomain;
                 message.IsQuery = false;
             }
             else
             {
                 var parts = result.Split(',');
                 if (parts.Length < 2)
                     throw new FormatException("Invalid HTTP DNS response: " + result);

                 // Machine-readable number: parse culture-invariantly.
                 int ttl = int.Parse(parts[1].Trim(), System.Globalization.CultureInfo.InvariantCulture);
                 var ips = parts[0].Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
                 foreach (var ip in ips)
                 {
                     message.AnswerRecords.Add(new ARecord(domainName, ttl, IPAddress.Parse(ip.Trim())));
                 }
                 message.ReturnCode = ReturnCode.NoError;
                 message.IsQuery = false;
             }
         }
     }
 }
Beispiel #16
0
		/// <summary>
		///   Queries a dns server for specified records.
		/// </summary>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClass"> Class that should be queried </param>
		/// <returns> The complete response of the dns server </returns>
		/// <exception cref="ArgumentException"> <paramref name="name"/> is null or empty </exception>
		public DnsMessage Resolve(string name, RecordType recordType, RecordClass recordClass)
		{
			if (String.IsNullOrEmpty(name))
				throw new ArgumentException("Name must be provided", "name");

			// Recursive standard query with EDNS enabled.
			var query = new DnsMessage()
			{
				IsQuery = true,
				OperationCode = OperationCode.Query,
				IsRecursionDesired = true,
				IsEDnsEnabled = true
			};
			query.Questions.Add(new DnsQuestion(name, recordType, recordClass));

			return SendMessage(query);
		}
		/// <summary>
		///   Builds a response message that mirrors this message's header settings and questions.
		/// </summary>
		/// <returns>
		///   A new <see cref="DnsMessage"/> with IsQuery cleared. The question list is copied, and the EDNS
		///   version and UDP payload size are copied when EDNS is enabled.
		/// </returns>
		public DnsMessage CreateResponseInstance()
		{
			var response = new DnsMessage();
			response.TransactionID = TransactionID;
			response.IsEDnsEnabled = IsEDnsEnabled;
			response.IsQuery = false;
			response.OperationCode = OperationCode;
			response.IsRecursionDesired = IsRecursionDesired;
			response.IsCheckingDisabled = IsCheckingDisabled;
			response.IsDnsSecOk = IsDnsSecOk;
			response.Questions = new List<DnsQuestion>(Questions);

			if (!IsEDnsEnabled)
				return response;

			response.EDnsOptions.Version = EDnsOptions.Version;
			response.EDnsOptions.UdpPayloadSize = EDnsOptions.UdpPayloadSize;
			return response;
		}
		/// <summary>
		///   Sends a caller-built message to the dns server and returns the answer.
		/// </summary>
		/// <param name="message"> Message that should be sent to the dns server </param>
		/// <returns> The complete response of the dns server </returns>
		/// <exception cref="ArgumentNullException"> <paramref name="message"/> is null </exception>
		/// <exception cref="ArgumentException"> <paramref name="message"/> has no questions </exception>
		public DnsMessage SendMessage(DnsMessage message)
		{
			if (message == null)
			{
				throw new ArgumentNullException(nameof(message));
			}

			bool hasQuestions = message.Questions != null && message.Questions.Count > 0;
			if (!hasQuestions)
			{
				throw new ArgumentException("At least one question must be provided", nameof(message));
			}

			return SendMessage<DnsMessage>(message);
		}
		/// <summary>
		///   Sends a caller-built message to the dns server and returns the answer as an asynchronous operation.
		/// </summary>
		/// <param name="message"> Message that should be sent to the dns server </param>
		/// <param name="token"> The token to monitor cancellation requests </param>
		/// <returns> The complete response of the dns server </returns>
		/// <exception cref="ArgumentNullException"> <paramref name="message"/> is null (thrown synchronously) </exception>
		/// <exception cref="ArgumentException"> <paramref name="message"/> has no questions (thrown synchronously) </exception>
		public Task<DnsMessage> SendMessageAsync(DnsMessage message, CancellationToken token = default(CancellationToken))
		{
			if (message == null)
			{
				throw new ArgumentNullException(nameof(message));
			}

			bool hasQuestions = message.Questions != null && message.Questions.Count > 0;
			if (!hasQuestions)
			{
				throw new ArgumentException("At least one question must be provided", nameof(message));
			}

			return SendMessageAsync<DnsMessage>(message, token);
		}
		/// <summary>
		///   Creates the message type matching the opcode in the header flags, then parses <paramref name="resultData"/> into it.
		/// </summary>
		/// <param name="resultData">Raw wire-format message.</param>
		/// <param name="isRequest">Forwarded to Parse.</param>
		/// <param name="tsigKeySelector">Selector forwarded to Parse for TSIG handling.</param>
		/// <param name="originalMac">MAC forwarded to Parse.</param>
		/// <returns>A <see cref="DnsUpdateMessage"/> for the Update opcode, otherwise a <see cref="DnsMessage"/>.</returns>
		public static DnsMessageBase Create(byte[] resultData, bool isRequest, DnsServer.SelectTsigKey tsigKeySelector, byte[] originalMac)
		{
			// Header flags are the 16-bit word at byte offset 2; bits 11-14 hold the opcode.
			int offset = 2;
			ushort headerFlags = ParseUShort(resultData, ref offset);
			var opCode = (OperationCode) ((headerFlags & 0x7800) >> 11);

			DnsMessageBase message = opCode == OperationCode.Update
				? (DnsMessageBase) new DnsUpdateMessage()
				: new DnsMessage();

			message.Parse(resultData, isRequest, tsigKeySelector, originalMac);
			return message;
		}
Beispiel #21
0
        /// <summary>
        ///   Resolves specified records as an asynchronous operation, validating the response with DNSSEC.
        /// </summary>
        /// <typeparam name="T"> Type of records, that should be returned </typeparam>
        /// <param name="name"> Domain that should be queried </param>
        /// <param name="recordType"> Type that should be queried </param>
        /// <param name="recordClass"> Class that should be queried </param>
        /// <param name="token"> The token to monitor cancellation requests </param>
        /// <returns> A list of matching <see cref="DnsRecordBase">records</see> together with their validation result </returns>
        /// <exception cref="ArgumentNullException"> <paramref name="name"/> is null </exception>
        /// <exception cref="DnsSecValidationException"> A validation result was Bogus or Indeterminate </exception>
        /// <exception cref="Exception"> No response, or a return code other than NoError / NxDomain </exception>
        /// <remarks>
        ///   Results are cached only when at least one record was found. CNAMEs are followed recursively.
        ///   NOTE(review): there is no CNAME-loop / depth limit.
        /// </remarks>
        public async Task <DnsSecResult <T> > ResolveSecureAsync <T>(DomainName name, RecordType recordType = RecordType.A, RecordClass recordClass = RecordClass.INet, CancellationToken token = default(CancellationToken))
            where T : DnsRecordBase
        {
            if (name == null)
            {
                throw new ArgumentNullException(nameof(name), "Name must be provided");
            }

            DnsCacheRecordList <T> cacheResult;

            if (_cache.TryGetRecords(name, recordType, recordClass, out cacheResult))
            {
                return(new DnsSecResult <T>(cacheResult, cacheResult.ValidationResult));
            }

            // Ask for DNSSEC data (DO bit) with checking disabled, so that validation is done locally by _validator.
            DnsMessage msg = await _dnsClient.ResolveAsync(name, recordType, recordClass, new DnsQueryOptions()
            {
                IsEDnsEnabled      = true,
                IsDnsSecOk         = true,
                IsCheckingDisabled = true,
                IsRecursionDesired = true
            }, token).ConfigureAwait(false);

            if ((msg == null) || ((msg.ReturnCode != ReturnCode.NoError) && (msg.ReturnCode != ReturnCode.NxDomain)))
            {
                throw new Exception("DNS request failed");
            }

            DnsSecValidationResult validationResult;

            CNameRecord cName = msg.AnswerRecords.Where(x => (x.RecordType == RecordType.CName) && (x.RecordClass == recordClass) && x.Name.Equals(name)).OfType <CNameRecord>().FirstOrDefault();

            if (cName != null)
            {
                DnsSecValidationResult cNameValidationResult = await _validator.ValidateAsync(name, RecordType.CName, recordClass, msg, new List <CNameRecord>() { cName }, null, token).ConfigureAwait(false);

                if ((cNameValidationResult == DnsSecValidationResult.Bogus) || (cNameValidationResult == DnsSecValidationResult.Indeterminate))
                {
                    throw new DnsSecValidationException("CNAME record could not be validated");
                }

                // The same response may already carry the target records for the canonical name.
                var records = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClass == recordClass) && x.Name.Equals(cName.CanonicalName)).OfType <T>().ToList();
                if (records.Count > 0)
                {
                    DnsSecValidationResult recordsValidationResult = await _validator.ValidateAsync(cName.CanonicalName, recordType, recordClass, msg, records, null, token).ConfigureAwait(false);

                    if ((recordsValidationResult == DnsSecValidationResult.Bogus) || (recordsValidationResult == DnsSecValidationResult.Indeterminate))
                    {
                        throw new DnsSecValidationException("CNAME matching records could not be validated");
                    }

                    // Agreeing results are kept, otherwise Unsigned. The target records are cached under the
                    // queried name, with a TTL bounded by both the CNAME and the records.
                    validationResult = cNameValidationResult == recordsValidationResult ? cNameValidationResult : DnsSecValidationResult.Unsigned;
                    _cache.Add(name, recordType, recordClass, records, validationResult, Math.Min(cName.TimeToLive, records.Min(x => x.TimeToLive)));

                    return(new DnsSecResult <T>(records, validationResult));
                }

                // Target records not included: resolve the canonical name separately.
                var cNameResults = await ResolveSecureAsync <T>(cName.CanonicalName, recordType, recordClass, token).ConfigureAwait(false);

                validationResult = cNameValidationResult == cNameResults.ValidationResult ? cNameValidationResult : DnsSecValidationResult.Unsigned;

                if (cNameResults.Records.Count > 0)
                {
                    _cache.Add(name, recordType, recordClass, cNameResults.Records, validationResult, Math.Min(cName.TimeToLive, cNameResults.Records.Min(x => x.TimeToLive)));
                }

                return(new DnsSecResult <T>(cNameResults.Records, validationResult));
            }

            // No CNAME: take records owned directly by the queried name. The list may be empty, in which
            // case the validator still checks the (negative) response.
            List <T> res = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClass == recordClass) && x.Name.Equals(name)).OfType <T>().ToList();

            validationResult = await _validator.ValidateAsync(name, recordType, recordClass, msg, res, null, token).ConfigureAwait(false);

            if ((validationResult == DnsSecValidationResult.Bogus) || (validationResult == DnsSecValidationResult.Indeterminate))
            {
                throw new DnsSecValidationException("Response records could not be validated");
            }

            if (res.Count > 0)
            {
                _cache.Add(name, recordType, recordClass, res, validationResult, res.Min(x => x.TimeToLive));
            }

            return(new DnsSecResult <T>(res, validationResult));
        }
        /// <summary>
        ///   Resolves records of the given type and class for <paramref name="name"/> without DNSSEC validation,
        ///   following CNAMEs and using (and filling) the record cache.
        /// </summary>
        /// <typeparam name="T"> Record type to return; answer records of other CLR types are filtered out. </typeparam>
        /// <param name="name"> Domain name to resolve </param>
        /// <param name="recordType"> Record type to query </param>
        /// <param name="recordClass"> Record class to query </param>
        /// <param name="state"> Resolution state passed through to <c>ResolveMessageAsync</c> </param>
        /// <param name="token"> Cancellation token </param>
        /// <returns> The matching records; an empty list for a negative (SOA-backed) answer. </returns>
        /// <exception cref="System.Exception"> The authoritative response contained neither matching records nor a SOA record. </exception>
        /// <remarks>
        ///   Everything cached here is stored with <c>DnsSecValidationResult.Indeterminate</c>, because no validation is performed.
        /// </remarks>
        private async Task <List <T> > ResolveAsyncInternal <T>(DomainName name, RecordType recordType, RecordClass recordClass, State state, CancellationToken token)
            where T : DnsRecordBase
        {
            List <T> cachedResults;

            if (_cache.TryGetRecords(name, recordType, recordClass, out cachedResults))
            {
                return(cachedResults);
            }

            List <CNameRecord> cachedCNames;

            // A cached CNAME short-circuits the network query. Negative answers are cached as empty lists
            // (see the SOA branch below), so an empty CNAME entry must not be dereferenced.
            if (_cache.TryGetRecords(name, RecordType.CName, recordClass, out cachedCNames) && (cachedCNames.Count > 0))
            {
                return(await ResolveAsyncInternal <T>(cachedCNames.First().CanonicalName, recordType, recordClass, state, token).ConfigureAwait(false));
            }

            DnsMessage msg = await ResolveMessageAsync(name, recordType, recordClass, state, token).ConfigureAwait(false);

            // check for cname
            List <DnsRecordBase> cNameRecords = msg.AnswerRecords.Where(x => (x.RecordType == RecordType.CName) && (x.RecordClass == recordClass) && x.Name.Equals(name)).ToList();

            if (cNameRecords.Count > 0)
            {
                _cache.Add(name, RecordType.CName, recordClass, cNameRecords, DnsSecValidationResult.Indeterminate, cNameRecords.Min(x => x.TimeToLive));

                DomainName canonicalName = ((CNameRecord)cNameRecords.First()).CanonicalName;

                // The same response may already carry the records of the CNAME target.
                List <DnsRecordBase> matchingAdditionalRecords = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClass == recordClass) && x.Name.Equals(canonicalName)).ToList();
                if (matchingAdditionalRecords.Count > 0)
                {
                    _cache.Add(canonicalName, recordType, recordClass, matchingAdditionalRecords, DnsSecValidationResult.Indeterminate, matchingAdditionalRecords.Min(x => x.TimeToLive));
                    return(matchingAdditionalRecords.OfType <T>().ToList());
                }

                return(await ResolveAsyncInternal <T>(canonicalName, recordType, recordClass, state, token).ConfigureAwait(false));
            }

            // check for "normal" answer
            List <DnsRecordBase> answerRecords = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClass == recordClass) && x.Name.Equals(name)).ToList();

            if (answerRecords.Count > 0)
            {
                _cache.Add(name, recordType, recordClass, answerRecords, DnsSecValidationResult.Indeterminate, answerRecords.Min(x => x.TimeToLive));
                return(answerRecords.OfType <T>().ToList());
            }

            // check for negative answer: a SOA for the name or an enclosing zone; cache the empty result
            // for the SOA's negative caching TTL
            SoaRecord soaRecord = msg.AuthorityRecords
                                  .Where(x =>
                                         (x.RecordType == RecordType.Soa) &&
                                         (name.Equals(x.Name) || name.IsSubDomainOf(x.Name)))
                                  .OfType <SoaRecord>()
                                  .FirstOrDefault();

            if (soaRecord != null)
            {
                _cache.Add(name, recordType, recordClass, new List <DnsRecordBase>(), DnsSecValidationResult.Indeterminate, soaRecord.NegativeCachingTTL);
                return(new List <T>());
            }

            // authoritive response does not contain answer
            throw new Exception("Could not resolve " + name);
        }
Beispiel #23
0
		/// <summary>
		///   Sends a custom message to the dns server and returns the answer asynchronously.
		/// </summary>
		/// <param name="message"> Message that should be sent to the dns server; it must contain at least one question </param>
		/// <param name="requestCallback">
		///   An <see cref="System.AsyncCallback" /> delegate that references the method to invoke when the operation is complete.
		/// </param>
		/// <param name="state">
		///   A user-defined object that contains information about the receive operation. This object is passed to the
		///   <paramref name="requestCallback" /> delegate when the operation is complete.
		/// </param>
		/// <returns>
		///   An <see cref="System.IAsyncResult" /> object that references the asynchronous receive.
		/// </returns>
		/// <exception cref="System.ArgumentNullException"> <paramref name="message" /> is null. </exception>
		/// <exception cref="System.ArgumentException"> <paramref name="message" /> has no questions. </exception>
		public IAsyncResult BeginSendMessage(DnsMessage message, AsyncCallback requestCallback, object state)
		{
			if (message == null)
				throw new ArgumentNullException(nameof(message));

			if ((message.Questions == null) || (message.Questions.Count == 0))
				throw new ArgumentException("At least one question must be provided", nameof(message));

			return BeginSendMessage<DnsMessage>(message, requestCallback, state);
		}
Beispiel #24
0
		/// <summary>
		///   Queries a dns server for the specified records asynchronously.
		/// </summary>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClass"> Class that should be queried </param>
		/// <param name="requestCallback">
		///   An <see cref="System.AsyncCallback" /> delegate that references the method to invoke when the operation is complete.
		/// </param>
		/// <param name="state">
		///   A user-defined object that contains information about the receive operation. This object is passed to the
		///   <paramref name="requestCallback" /> delegate when the operation is complete.
		/// </param>
		/// <returns>
		///   An <see cref="System.IAsyncResult" /> object that references the asynchronous receive.
		/// </returns>
		/// <exception cref="System.ArgumentException"> <paramref name="name" /> is null or empty. </exception>
		public IAsyncResult BeginResolve(string name, RecordType recordType, RecordClass recordClass, AsyncCallback requestCallback, object state)
		{
			if (String.IsNullOrEmpty(name))
			{
				throw new ArgumentException("Name must be provided", nameof(name));
			}

			// Standard query with recursion desired and EDNS enabled.
			DnsMessage message = new DnsMessage() { IsQuery = true, OperationCode = OperationCode.Query, IsRecursionDesired = true, IsEDnsEnabled = true };
			message.Questions.Add(new DnsQuestion(name, recordType, recordClass));

			return BeginSendMessage(message, requestCallback, state);
		}
Beispiel #25
0
        /// <summary>
        ///   Completion callback for an asynchronous UDP receive. It parses the query, processes it, truncates the
        ///   response to the allowed UDP payload size if needed, and starts sending the response back to the client.
        /// </summary>
        /// <param name="ar"> Result of the pending receive operation </param>
        /// <remarks>
        ///   Any exception thrown here is passed to HandleUdpException. A parse failure is rethrown wrapped, so no
        ///   response is sent for an unparsable datagram.
        /// </remarks>
        private void EndUdpReceive(IAsyncResult ar)
        {
            try
            {
                // Clear the active-listener flag and call StartUdpListen before processing this datagram
                // (presumably to re-arm the listener for the next client as early as possible — confirm).
                lock (_udpListener)
                {
                    _hasActiveUdpListener = false;
                }
                StartUdpListen();

                IPEndPoint endpoint;

                byte[] buffer = _udpListener.EndReceive(ar, out endpoint);

                DnsMessageBase query;
                byte[]         originalMac;
                try
                {
                    query       = DnsMessageBase.CreateByFlag(buffer, TsigKeySelector, null);
                    // The request MAC (if TSIG-signed) is needed to sign the response.
                    originalMac = (query.TSigOptions == null) ? null : query.TSigOptions.Mac;
                }
                catch (Exception e)
                {
                    throw new Exception("Error parsing dns query", e);
                }

                DnsMessageBase response;
                try
                {
                    response = ProcessMessage(query, endpoint.Address, ProtocolType.Udp);
                }
                catch (Exception ex)
                {
                    OnExceptionThrown(ex);
                    response = null;
                }

                // Processing failed or produced nothing: answer with the query itself, turned into a ServerFailure response.
                if (response == null)
                {
                    response         = query;
                    query.IsQuery    = false;
                    query.ReturnCode = ReturnCode.ServerFailure;
                }

                int length = response.Encode(false, originalMac, out buffer);

                #region Truncating
                // Only DnsMessage responses are truncated. The limit is 512 bytes, or the advertised EDNS payload size
                // (never below 512) when both query and response use EDNS.
                DnsMessage message = response as DnsMessage;

                if (message != null)
                {
                    int maxLength = 512;
                    if (query.IsEDnsEnabled && message.IsEDnsEnabled)
                    {
                        maxLength = Math.Max(512, (int)message.EDnsOptions.UdpPayloadSize);
                    }

                    // Each step removes records and re-encodes; MaximumLength is used only as an estimate of the
                    // bytes saved, so the real length is always recomputed by Encode.
                    // NOTE(review): if the message still exceeds maxLength after all sections below have been
                    // emptied, no branch changes it and this loop would never terminate.
                    while (length > maxLength)
                    {
                        // First step: remove data from additional records except the opt record
                        if ((message.IsEDnsEnabled && (message.AdditionalRecords.Count > 1)) || (!message.IsEDnsEnabled && (message.AdditionalRecords.Count > 0)))
                        {
                            for (int i = message.AdditionalRecords.Count - 1; i >= 0; i--)
                            {
                                if (message.AdditionalRecords[i].RecordType != RecordType.Opt)
                                {
                                    message.AdditionalRecords.RemoveAt(i);
                                }
                            }

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Second step: drop authority records from the end until the estimate fits; set TC.
                        int savedLength = 0;
                        if (message.AuthorityRecords.Count > 0)
                        {
                            for (int i = message.AuthorityRecords.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.AuthorityRecords[i].MaximumLength;
                                message.AuthorityRecords.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Third step: drop answer records from the end the same way.
                        if (message.AnswerRecords.Count > 0)
                        {
                            for (int i = message.AnswerRecords.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.AnswerRecords[i].MaximumLength;
                                message.AnswerRecords.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Last resort: drop questions.
                        if (message.Questions.Count > 0)
                        {
                            for (int i = message.Questions.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.Questions[i].MaximumLength;
                                message.Questions.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                        }
                    }
                }
                #endregion

                _udpListener.BeginSend(buffer, 0, length, endpoint, EndUdpSend, null);
            }
            catch (Exception e)
            {
                HandleUdpException(e);
            }
        }
		/// <summary>
		///   Sends a query for the given name, type and class to the dns server and waits for its response.
		/// </summary>
		/// <param name="name"> Domain name to look up (required) </param>
		/// <param name="recordType"> Record type to ask for; A by default </param>
		/// <param name="recordClass"> Record class to ask for; INet by default </param>
		/// <param name="options">
		///   Optional query settings. When null, recursion is requested and EDNS is enabled; otherwise the
		///   recursion-desired flag, checking-disabled flag and EDNS options are taken from here.
		/// </param>
		/// <returns> The complete response message of the dns server </returns>
		/// <exception cref="System.ArgumentNullException"> <paramref name="name" /> is null. </exception>
		public DnsMessage Resolve(DomainName name, RecordType recordType = RecordType.A, RecordClass recordClass = RecordClass.INet, DnsQueryOptions options = null)
		{
			if (name == null)
				throw new ArgumentNullException(nameof(name), "Name must be provided");

			var query = new DnsMessage()
			{
				IsQuery = true,
				OperationCode = OperationCode.Query,
				IsRecursionDesired = true,
				IsEDnsEnabled = true
			};

			if (options != null)
			{
				// Caller-supplied settings replace the defaults from the initializer.
				query.IsRecursionDesired = options.IsRecursionDesired;
				query.IsCheckingDisabled = options.IsCheckingDisabled;
				query.EDnsOptions = options.EDnsOptions;
			}
			else
			{
				// No options: keep recursion and EDNS switched on (re-applied explicitly).
				query.IsRecursionDesired = true;
				query.IsEDnsEnabled = true;
			}

			query.Questions.Add(new DnsQuestion(name, recordType, recordClass));
			return SendMessage(query);
		}
        /// <summary>
        ///   Iteratively queries nameservers (with recursion disabled) for a record set, following referrals until
        ///   an authoritative response is received.
        /// </summary>
        /// <param name="name"> Name to query </param>
        /// <param name="recordType"> Type to query; for DS the nameservers of the parent name are asked </param>
        /// <param name="recordClass"> Class to query </param>
        /// <param name="state"> Resolution state; its QueryCount is advanced on every iteration and bounds the total work </param>
        /// <param name="token"> Cancellation token, passed to every query </param>
        /// <returns> The first authoritative response whose return code is NoError or NxDomain </returns>
        /// <exception cref="System.Exception">
        ///   A non-authoritative response carried no usable referral, or the query budget was exhausted.
        /// </exception>
        private async Task <DnsMessage> ResolveMessageAsync(DomainName name, RecordType recordType, RecordClass recordClass, State state, CancellationToken token)
        {
            for (; state.QueryCount <= MaximumReferalCount; state.QueryCount++)
            {
                // DS records are served by the parent zone, so pick the parent's best known nameservers.
                DnsMessage msg = await new DnsClient(GetBestNameservers(recordType == RecordType.Ds ? name.GetParentName() : name), QueryTimeout)
                {
                    IsResponseValidationEnabled = IsResponseValidationEnabled,
                    Is0x20ValidationEnabled     = Is0x20ValidationEnabled
                }.ResolveAsync(name, recordType, recordClass, new DnsQueryOptions()
                {
                    IsRecursionDesired = false,
                    IsEDnsEnabled      = true
                }, token).ConfigureAwait(false);

                // NOTE(review): a missing response or any other return code simply consumes one iteration of the
                // budget; the next iteration asks GetBestNameservers again.
                if ((msg != null) && ((msg.ReturnCode == ReturnCode.NoError) || (msg.ReturnCode == ReturnCode.NxDomain)))
                {
                    if (msg.IsAuthoritiveAnswer)
                    {
                        return(msg);
                    }

                    // Referral: NS records for the name itself or one of its enclosing zones.
                    List <NsRecord> referalRecords = msg.AuthorityRecords
                                                     .Where(x =>
                                                            (x.RecordType == RecordType.Ns) &&
                                                            (name.Equals(x.Name) || name.IsSubDomainOf(x.Name)))
                                                     .OfType <NsRecord>()
                                                     .ToList();

                    if (referalRecords.Count > 0)
                    {
                        // Only a referral to exactly one zone is followed.
                        if (referalRecords.GroupBy(x => x.Name).Count() == 1)
                        {
                            // Prefer glue: address records in the additional section for the referred nameservers.
                            var newServers = referalRecords.Join(msg.AdditionalRecords.OfType <AddressRecordBase>(), x => x.NameServer, x => x.Name, (x, y) => new { y.Address, TimeToLive = Math.Min(x.TimeToLive, y.TimeToLive) }).ToList();

                            if (newServers.Count > 0)
                            {
                                DomainName zone = referalRecords.First().Name;

                                foreach (var newServer in newServers)
                                {
                                    _nameserverCache.Add(zone, newServer.Address, newServer.TimeToLive);
                                }

                                continue;
                            }
                            else
                            {
                                // No glue: resolve the address of the first referred nameserver separately.
                                NsRecord firstReferal = referalRecords.First();

                                var newLookedUpServers = await ResolveHostWithTtlAsync(firstReferal.NameServer, state, token).ConfigureAwait(false);

                                foreach (var newServer in newLookedUpServers)
                                {
                                    _nameserverCache.Add(firstReferal.Name, newServer.Item1, Math.Min(firstReferal.TimeToLive, newServer.Item2));
                                }

                                if (newLookedUpServers.Count > 0)
                                {
                                    continue;
                                }
                            }
                        }
                    }

                    // Response of best known server is not authoritive and has no referrals --> No chance to get a result
                    throw new Exception("Could not resolve " + name);
                }
            }

            // query limit reached without authoritive answer
            throw new Exception("Could not resolve " + name);
        }
        /// <summary>
        ///   Builds a DNS answer for <paramref name="name"/> from the domain value returned by GetDomainValue.
        /// </summary>
        /// <param name="question"> The question being answered; its type and class drive the lookup </param>
        /// <param name="name"> Name whose domain value is used </param>
        /// <returns>
        ///   The answer message, or null when the recursion limit is exceeded or no domain value exists.
        ///   If the value has an alias, the result of InternalGetAnswer for the alias is returned instead.
        ///   If it lists nameservers that resolve to addresses, the response of a DnsClient query to them is returned.
        ///   Otherwise A/AAAA records are synthesized from the value's addresses.
        /// </returns>
        private DnsMessage GetDotBitAnswerForName(DnsQuestion question, string name)
        {
            // TTL used for A/AAAA records synthesized from the domain value.
            const int SynthesizedRecordTtl = 60;

            try
            {
                // recursionLevel guards alias chains (InternalGetAnswer may re-enter this method); decremented in finally.
                recursionLevel++;

                if (recursionLevel > maxRecursion)
                {
                    ConsoleUtils.WriteWarning("Max recursion reached");
                    return null;
                }

                DomainValue value = GetDomainValue(name);
                if (value == null)
                    return null;

                value.ImportDefaultMap();

                //TODO: delegate not implemented
                if (!string.IsNullOrWhiteSpace(value.@Delegate))
                    ConsoleUtils.WriteWarning("delegate setting not implemented: {0}", value.@Delegate);

                //TODO: import not implemented
                if (!string.IsNullOrWhiteSpace(value.Import))
                    ConsoleUtils.WriteWarning("import setting not implemented: {0}", value.Import);

                if (value.Alias != null)
                {
                    // TODO(review): the alias is queried verbatim. Relative (non dot-terminated) aliases are not
                    // qualified against the current name; decide on the intended semantics.
                    DnsQuestion newQuestion = new DnsQuestion(value.Alias, question.RecordType, question.RecordClass);
                    return InternalGetAnswer(newQuestion);
                }

                DnsMessage answer = new DnsMessage()
                {
                    Questions = new List<DnsQuestion>() { question }
                };

                bool any = question.RecordType == RecordType.Any;

                var nsnames = value.Ns;
                if (nsnames != null && nsnames.Any()) // NS overrides all
                {
                    List<IPAddress> nameservers = GetDotBitNameservers(nsnames);
                    if (nameservers != null && nameservers.Count > 0)
                    {
                        var client = new DnsClient(nameservers, 2000);
                        if (!string.IsNullOrWhiteSpace(value.Translate))
                            name = value.Translate;
                        answer = client.Resolve(name, question.RecordType, question.RecordClass);
                    }
                }
                else
                {
                    if (any || question.RecordType == RecordType.A)
                    {
                        // A missing address list is treated as "no addresses" instead of failing.
                        var addresses = value.Ips;
                        if (addresses != null)
                        {
                            foreach (var address in addresses)
                                answer.AnswerRecords.Add(new ARecord(name, SynthesizedRecordTtl, address));
                        }
                    }
                    if (any || question.RecordType == RecordType.Aaaa)
                    {
                        var addresses = value.Ip6s;
                        if (addresses != null)
                        {
                            foreach (var address in addresses)
                                answer.AnswerRecords.Add(new AaaaRecord(name, SynthesizedRecordTtl, address));
                        }
                    }
                }

                return answer;
            }
            finally
            {
                recursionLevel--;
            }
        }
Beispiel #29
0
        /// <summary>
        ///   Handles one UDP datagram. It awaits a receive and raises ClientConnected (whose handler may refuse the client).
        ///   It then processes the query, truncates the response to the allowed UDP payload size, and sends it back.
        /// </summary>
        /// <remarks>
        ///   Declared async void, so callers cannot await it. An ObjectDisposedException from ReceiveAsync ends the
        ///   handler silently; other exceptions in the main body are passed to OnExceptionThrownAsync. The finally
        ///   block always increments _availableUdpListener and calls StartUdpListenerTask.
        /// </remarks>
        private async void HandleUdpListenerAsync()
        {
            try
            {
                UdpReceiveResult receiveResult;
                try
                {
                    receiveResult = await _udpListener.ReceiveAsync();
                }
                catch (ObjectDisposedException)
                {
                    return;
                }
                finally
                {
                    // The receive is over (successfully or not), so this handler no longer holds the listener.
                    lock (_listenerLock)
                    {
                        _hasActiveUdpListener = false;
                    }
                }

                ClientConnectedEventArgs clientConnectedEventArgs = new ClientConnectedEventArgs(ProtocolType.Udp, receiveResult.RemoteEndPoint);
                await ClientConnected.RaiseAsync(this, clientConnectedEventArgs);

                if (clientConnectedEventArgs.RefuseConnect)
                {
                    return;
                }

                // Called before processing this datagram (presumably so the next receive can start — confirm).
                StartUdpListenerTask();

                byte[] buffer = receiveResult.Buffer;

                DnsMessageBase query;
                byte[]         originalMac;
                try
                {
                    query       = DnsMessageBase.CreateByFlag(buffer, TsigKeySelector, null);
                    // The request MAC (if TSIG-signed) is needed to sign the response.
                    originalMac = query.TSigOptions?.Mac;
                }
                catch (Exception e)
                {
                    throw new Exception("Error parsing dns query", e);
                }

                DnsMessageBase response;
                try
                {
                    response = await ProcessMessageAsync(query, ProtocolType.Udp, receiveResult.RemoteEndPoint);
                }
                catch (Exception ex)
                {
                    OnExceptionThrownAsync(ex);
                    response = null;
                }

                // Processing failed or produced nothing: answer with the query itself, turned into a ServerFailure response.
                if (response == null)
                {
                    response         = query;
                    query.IsQuery    = false;
                    query.ReturnCode = ReturnCode.ServerFailure;
                }

                int length = response.Encode(false, originalMac, out buffer);

                #region Truncating
                // Only DnsMessage responses are truncated. The limit is 512 bytes, or the advertised EDNS payload size
                // (never below 512) when both query and response use EDNS.
                DnsMessage message = response as DnsMessage;

                if (message != null)
                {
                    int maxLength = 512;
                    if (query.IsEDnsEnabled && message.IsEDnsEnabled)
                    {
                        maxLength = Math.Max(512, (int)message.EDnsOptions.UdpPayloadSize);
                    }

                    // Each step removes records and re-encodes; MaximumLength is only an estimate of the bytes saved.
                    // NOTE(review): if the message still exceeds maxLength after all sections below have been
                    // emptied, no branch changes it and this loop would never terminate.
                    while (length > maxLength)
                    {
                        // First step: remove data from additional records except the opt record
                        if ((message.IsEDnsEnabled && (message.AdditionalRecords.Count > 1)) || (!message.IsEDnsEnabled && (message.AdditionalRecords.Count > 0)))
                        {
                            for (int i = message.AdditionalRecords.Count - 1; i >= 0; i--)
                            {
                                if (message.AdditionalRecords[i].RecordType != RecordType.Opt)
                                {
                                    message.AdditionalRecords.RemoveAt(i);
                                }
                            }

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Second step: drop authority records from the end until the estimate fits; set TC.
                        int savedLength = 0;
                        if (message.AuthorityRecords.Count > 0)
                        {
                            for (int i = message.AuthorityRecords.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.AuthorityRecords[i].MaximumLength;
                                message.AuthorityRecords.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Third step: drop answer records from the end the same way.
                        if (message.AnswerRecords.Count > 0)
                        {
                            for (int i = message.AnswerRecords.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.AnswerRecords[i].MaximumLength;
                                message.AnswerRecords.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                            continue;
                        }

                        // Last resort: drop questions.
                        if (message.Questions.Count > 0)
                        {
                            for (int i = message.Questions.Count - 1; i >= 0; i--)
                            {
                                savedLength += message.Questions[i].MaximumLength;
                                message.Questions.RemoveAt(i);

                                if ((length - savedLength) < maxLength)
                                {
                                    break;
                                }
                            }

                            message.IsTruncated = true;

                            length = message.Encode(false, originalMac, out buffer);
                        }
                    }
                }
                #endregion

                await _udpListener.SendAsync(buffer, length, receiveResult.RemoteEndPoint);
            }
            catch (Exception ex)
            {
                OnExceptionThrownAsync(ex);
            }
            finally
            {
                lock (_listenerLock)
                {
                    _availableUdpListener++;
                }
                StartUdpListenerTask();
            }
        }
 /// <summary>
 /// Tells whether a query attempt failed: no response at all, or a response whose return code
 /// is neither NoError nor NxDomain (NxDomain counts as a valid, negative answer).
 /// </summary>
 /// <param name="message">The response, or null when none was received.</param>
 /// <returns>true if the query must be treated as failed; otherwise false.</returns>
 private bool IsFailedQuery(DnsMessage message)
 {
     bool hasUsableAnswer = message != null
                            && (message.ReturnCode == ReturnCode.NoError || message.ReturnCode == ReturnCode.NxDomain);
     return !hasUsableAnswer;
 }
        /// <summary>
        ///   Determines the DNSSEC validation result for the records of a response.
        /// </summary>
        /// <typeparam name="TRecord"> Type of the result records </typeparam>
        /// <param name="name"> Queried name </param>
        /// <param name="recordType"> Queried record type </param>
        /// <param name="recordClass"> Queried record class </param>
        /// <param name="msg"> Response whose answer and authority sections supply the RRSIG records </param>
        /// <param name="resultRecords"> Records extracted from the response; when empty, non-existence is validated instead </param>
        /// <param name="state"> Caller state passed through to the validation helpers </param>
        /// <param name="token"> Cancellation token </param>
        /// <returns>
        ///   If no currently valid RRSIG covering <paramref name="name"/> exists: Unsigned when ValidateOptOut succeeds,
        ///   otherwise Bogus. Otherwise the result of ValidateRrSigAsync for non-empty <paramref name="resultRecords"/>,
        ///   or of ValidateNonExistenceAsync when there are no result records.
        /// </returns>
        public async Task <DnsSecValidationResult> ValidateAsync <TRecord>(DomainName name, RecordType recordType, RecordClass recordClass, DnsMessage msg, List <TRecord> resultRecords, TState state, CancellationToken token)
            where TRecord : DnsRecordBase
        {
            // Read the clock once so every signature is checked against the same instant, and compare in UTC.
            // DateTime comparison ignores DateTimeKind, so comparing UTC-kind signature times with local time
            // would be off by the local UTC offset, and local times are ambiguous around DST changes.
            // ToUniversalTime() converts Local values, keeps Utc values, and treats Unspecified as local (matching
            // how the previous DateTime.Now comparison treated them).
            DateTime utcNow = DateTime.UtcNow;

            List <RrSigRecord> rrSigRecords = msg
                                              .AnswerRecords.OfType <RrSigRecord>()
                                              .Union(msg.AuthorityRecords.OfType <RrSigRecord>())
                                              .Where(x => name.IsEqualOrSubDomainOf(x.SignersName)
                                                          && (x.SignatureInception.ToUniversalTime() <= utcNow)
                                                          && (x.SignatureExpiration.ToUniversalTime() >= utcNow))
                                              .ToList();

            if (rrSigRecords.Count == 0)
            {
                return(await ValidateOptOut(name, recordClass, state, token) ? DnsSecValidationResult.Unsigned : DnsSecValidationResult.Bogus);
            }

            // The signer name of the RRSIG with the most labels is taken as the zone apex.
            DomainName zoneApex = rrSigRecords.OrderByDescending(x => x.Labels).First().SignersName;

            if (resultRecords.Count != 0)
            {
                return(await ValidateRrSigAsync(name, recordType, recordClass, resultRecords, rrSigRecords, zoneApex, msg, state, token));
            }

            return(await ValidateNonExistenceAsync(name, recordType, recordClass, rrSigRecords, DomainName.Asterisk + zoneApex, zoneApex, msg, state, token));
        }
Beispiel #32
0
 /// <summary>
 /// Converts <paramref name="message"/> in place into a ServerFailure response and encodes it.
 /// </summary>
 /// <param name="message">Message to convert; it is mutated (marked as a response, return code set).</param>
 /// <param name="buffer">Receives the encoded message bytes.</param>
 /// <remarks>
 /// NOTE(review): the length returned by Encode is discarded, and the first Encode argument is false
 /// (presumably "no TCP length prefix") — confirm that callers can use the whole buffer as-is.
 /// </remarks>
 public static void ReturnDnsMessageServerFailure(DnsMessage message, out byte[] buffer)
 {
     message.IsQuery = false;
     message.ReturnCode = ReturnCode.ServerFailure;

     message.Encode(false, out buffer);
 }