private async Task QueryCore(
            bool async,
            IReadOnlyList <string> queries,
            PostgresCommand command,
            CancellationToken cancellationToken)
        {
            // Sends each statement in turn. For every statement except the
            // last, incoming messages are read (and not kept) up to and
            // including a CommandCompleteMessage, so that what remains for the
            // caller to read belongs to the final statement.
            //
            // This pattern is kind of silly: ideally it shouldn't be permitted
            // to execute multiple commands in a single query at all.
            //
            // NOTE(review): a statement that fails or is empty presumably never
            // produces a CommandCompleteMessage; whether ReadNextMessage throws
            // in those cases is not visible here -- confirm, otherwise the
            // drain loop below does not terminate. Any message the server sends
            // after CommandComplete (e.g. in response to Sync) is also left
            // unread by this loop.
            var lastIndex = queries.Count - 1;

            for (var index = 0; index <= lastIndex; ++index)
            {
                await QueryCore(async, queries[index], command, cancellationToken)
                .ConfigureAwait(false);

                if (index == lastIndex)
                {
                    // The caller consumes the final statement's results.
                    continue;
                }

                var sawCommandComplete = false;

                while (!sawCommandComplete)
                {
                    var message = await ReadNextMessage(async, cancellationToken)
                                  .ConfigureAwait(false);

                    sawCommandComplete = message is CommandCompleteMessage;
                }
            }
        }
Пример #2
0
        /// <summary>
        /// Stores the reader's collaborators and caches, as individual booleans,
        /// which <see cref="CommandBehavior"/> options were requested.
        /// </summary>
        /// <param name="behavior">The behavior flags the command was executed with.</param>
        /// <param name="connection">The owning connection.</param>
        /// <param name="command">The command that produced this reader.</param>
        /// <param name="cancellationToken">Token stored for later operations on the reader.</param>
        public PostgresDbDataReader(
            CommandBehavior behavior,
            PostgresDbConnectionBase connection,
            PostgresCommand command,
            CancellationToken cancellationToken)
        {
            _cancellationToken = cancellationToken;
            _command           = command;
            _connection        = connection;
            _behavior          = behavior;

            // (value & flag) == flag is exactly what Enum.HasFlag tests.
            _behaviorCloseConnection =
                (behavior & CommandBehavior.CloseConnection) == CommandBehavior.CloseConnection;
            _behaviorKeyInfo =
                (behavior & CommandBehavior.KeyInfo) == CommandBehavior.KeyInfo;
            _behaviorSchemaOnly =
                (behavior & CommandBehavior.SchemaOnly) == CommandBehavior.SchemaOnly;
            _behaviorSequentialAccess =
                (behavior & CommandBehavior.SequentialAccess) == CommandBehavior.SequentialAccess;
            // NOTE(review): field name is misspelled ("Singla"); it is declared
            // outside this view, so renaming must happen together with it.
            _behaviorSinglaResult =
                (behavior & CommandBehavior.SingleResult) == CommandBehavior.SingleResult;
            _behaviorSingleRow =
                (behavior & CommandBehavior.SingleRow) == CommandBehavior.SingleRow;
        }
        /// <summary>
        /// Validates the parameter count, rewrites the command text into its
        /// individual statements and sends them.
        /// </summary>
        /// <returns>
        /// The task from the single-statement <c>QueryCore</c> overload when
        /// there is exactly one statement, otherwise from the multi-statement one.
        /// </returns>
        /// <exception cref="IndexOutOfRangeException">
        /// More than <see cref="short.MaxValue"/> parameters were supplied
        /// (the count is sent as a 16-bit value in the Bind message).
        /// </exception>
        internal Task Query(
            bool async,
            PostgresCommand command,
            CancellationToken cancellationToken)
        {
            // Reference: libpq's request flow,
            // https://github.com/LuaDist/libpq/blob/4a90601e5d395da904b43116ffb3052e86bdc8ec/src/interfaces/libpq/fe-exec.c#L1365
            var parameterCount = command.Parameters.Count;

            if (parameterCount > short.MaxValue)
            {
                throw new IndexOutOfRangeException(
                          $"Too many arguments provided for query. Found {parameterCount}, limit {short.MaxValue}.");
            }

            var statements = command.GetRewrittenCommandText();

            return statements.Count == 1
                ? QueryCore(async, statements[0], command, cancellationToken)
                : QueryCore(async, statements, command, cancellationToken);
        }
        /// <summary>
        /// Rewrites <c>command.CommandText</c> for the extended query protocol.
        /// The text is split into statements on top-level <c>;</c>, and each
        /// <c>@name</c> placeholder is replaced with <c>$n</c>, where n is the
        /// 1-based index of <c>name</c> in <c>command.Parameters</c>.
        /// </summary>
        /// <remarks>
        /// Text inside these constructs is copied verbatim and is never split
        /// or rewritten:
        /// <list type="bullet">
        /// <item>string literals ('x', and E'x' where backslash escapes apply);</item>
        /// <item>quoted identifiers ("x");</item>
        /// <item>dollar-quoted strings ($$x$$, $tag$x$tag$);</item>
        /// <item>line comments (-- x);</item>
        /// <item>block comments (/* x */), which may nest.</item>
        /// </list>
        /// <c>@@</c> is an escape that produces a single literal <c>@</c>.
        /// An unterminated literal or comment runs to the end of the text, and
        /// the server is left to report the syntax error.
        /// </remarks>
        /// <param name="settings">
        /// Passed to <c>DemandStandardSettings</c> before scanning.
        /// NOTE(review): presumably that rejects settings (such as
        /// standard_conforming_strings = off) under which backslashes in plain
        /// '...' strings would be escapes; this scanner treats them as literal.
        /// </param>
        /// <param name="command">Supplies the SQL text and the parameter list.</param>
        /// <returns>
        /// The trimmed, non-blank statements in source order, or
        /// <c>EmptyList &lt;string&gt;.Value</c> when the command text is null.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// An <c>@name</c> placeholder has no matching entry in the parameter list.
        /// </exception>
        public static unsafe IReadOnlyList <string> Perform(
            IReadOnlyList <PostgresPropertySetting> settings,
            PostgresCommand command)
        {
            var parameters = command.Parameters;
            var sql        = command.CommandText;

            if (sql == null)
            {
                return(EmptyList <string> .Value);
            }

            DemandStandardSettings(settings);

            var queries = new List <string>();
            var sb      = StringBuilderPool.Get(sql.Length);

            try
            {
                var i = 0;

                while (i < sql.Length)
                {
                    var chr     = sql[i];
                    var nextChr = i + 1 < sql.Length ? sql[i + 1] : '\0';

                    // Exclusive end of the span starting at i that is copied
                    // verbatim into sb.
                    int end;

                    switch (chr)
                    {
                    // 'foo', and E'foo' / e'foo' where backslash escapes apply.
                    case '\'':
                        end = FindQuotedEnd(sql, i, IsEscapeStringPrefix(sql, i));
                        break;

                    // "Quoted identifier", which may legally contain ; ' or @.
                    case '"':
                        end = FindQuotedEnd(sql, i, false);
                        break;

                    // -- comment, up to and including the newline.
                    case '-' when nextChr == '-':
                        var newline = sql.IndexOf('\n', i + 2);
                        end = newline == -1 ? sql.Length : newline + 1;
                        break;

                    // /* comment */ (PostgreSQL block comments nest).
                    case '/' when nextChr == '*':
                        end = FindBlockCommentEnd(sql, i);
                        break;

                    // $$foo$$ and $bar$foo$bar$. A '$' that does not open a
                    // dollar quote (e.g. positional parameter $5) is copied as
                    // a single character.
                    case '$':
                        end = FindDollarQuotedEnd(sql, i);
                        break;

                    // Handle @@NotNamedParameter: "@@" emits a literal '@'.
                    case '@' when nextChr == '@':
                        sb.Append('@');
                        i += 2;
                        continue;

                    // Handle @NamedParameter: replaced by $n.
                    case '@':
                        var nameStart = i + 1;
                        var nameEnd   = nameStart;

                        while (nameEnd < sql.Length && IsParameterNameChar(sql[nameEnd]))
                        {
                            ++nameEnd;
                        }

                        var name       = sql.Substring(nameStart, nameEnd - nameStart);
                        var paramIndex = parameters.IndexOf(name);

                        if (paramIndex == -1)
                        {
                            throw new ArgumentOutOfRangeException(
                                      "parameterName", name,
                                      "Parameter inside query was not found inside parameter list.");
                        }

                        sb.Append('$');
                        sb.Append(paramIndex + 1);
                        i = nameEnd;
                        continue;

                    // Top-level statement separator.
                    case ';':
                        FlushQuery();
                        ++i;
                        continue;

                    default:
                        end = i + 1;
                        break;
                    }

                    sb.Append(sql, i, end - i);
                    i = end;
                }

                if (sb.Length > 0)
                {
                    FlushQuery();
                }

                return(queries);
            }
            finally
            {
                // Return the pooled builder even when a parameter lookup throws.
                StringBuilderPool.Free(ref sb);
            }

            // Moves the statement accumulated in sb (trimmed) into queries,
            // skipping it when it is blank.
            void FlushQuery()
            {
                var singleSqlCommand = sb.ToStringTrim();
                sb.Clear();

                if (!string.IsNullOrWhiteSpace(singleSqlCommand))
                {
                    queries.Add(singleSqlCommand);
                }
            }

            // Characters accepted in an @parameter name.
            bool IsParameterNameChar(char c) =>
                char.IsLetterOrDigit(c) || c == '_';

            // Characters that can continue an unquoted identifier.
            bool IsIdentifierChar(char c) =>
                char.IsLetterOrDigit(c) || c == '_' || c == '$';

            // True when the quote at quoteIndex opens an escape string: it is
            // directly preceded by E or e, and that letter is not the tail of
            // a longer identifier such as ELSE.
            bool IsEscapeStringPrefix(string text, int quoteIndex) =>
                quoteIndex >= 1 &&
                (text[quoteIndex - 1] == 'E' || text[quoteIndex - 1] == 'e') &&
                (quoteIndex < 2 || !IsIdentifierChar(text[quoteIndex - 2]));

            // Returns the index just past the quote that closes the literal
            // opened at openIndex. The quote character is text[openIndex].
            // A doubled quote is content. With backslashEscapes, a backslash
            // also escapes the character after it. An unterminated literal
            // runs to the end of the text.
            int FindQuotedEnd(string text, int openIndex, bool backslashEscapes)
            {
                var quote = text[openIndex];
                var pos   = openIndex + 1;

                while (pos < text.Length)
                {
                    var c = text[pos];

                    if (backslashEscapes && c == '\\')
                    {
                        pos += 2;
                    }
                    else if (c != quote)
                    {
                        ++pos;
                    }
                    else if (pos + 1 < text.Length && text[pos + 1] == quote)
                    {
                        pos += 2;
                    }
                    else
                    {
                        return pos + 1;
                    }
                }

                return text.Length;
            }

            // Returns the index just past the "*/" that closes the block
            // comment opened by "/*" at openIndex. Nesting is honored, and an
            // unterminated comment runs to the end of the text.
            int FindBlockCommentEnd(string text, int openIndex)
            {
                var depth = 1;
                var pos   = openIndex + 2;

                while (pos < text.Length)
                {
                    var following = pos + 1 < text.Length ? text[pos + 1] : '\0';

                    if (text[pos] == '*' && following == '/')
                    {
                        pos += 2;

                        if (--depth == 0)
                        {
                            return pos;
                        }
                    }
                    else if (text[pos] == '/' && following == '*')
                    {
                        pos += 2;
                        ++depth;
                    }
                    else
                    {
                        ++pos;
                    }
                }

                return text.Length;
            }

            // Handles a '$' at openIndex that may open a dollar quote: either
            // $$, or $tag$ where tag is a letter or '_' followed by letters,
            // digits or '_'. If it does, returns the index just past the
            // matching closing tag, or the end of the text when unterminated.
            // Otherwise returns openIndex + 1.
            int FindDollarQuotedEnd(string text, int openIndex)
            {
                // A '$' inside an identifier (e.g. col$x) never opens a quote.
                if (openIndex > 0 && IsIdentifierChar(text[openIndex - 1]))
                {
                    return openIndex + 1;
                }

                var tagEnd = openIndex + 1;

                if (tagEnd < text.Length &&
                    (char.IsLetter(text[tagEnd]) || text[tagEnd] == '_'))
                {
                    do
                    {
                        ++tagEnd;
                    }
                    while (tagEnd < text.Length &&
                           (char.IsLetterOrDigit(text[tagEnd]) || text[tagEnd] == '_'));
                }

                if (tagEnd >= text.Length || text[tagEnd] != '$')
                {
                    return openIndex + 1;
                }

                var tag   = text.Substring(openIndex, tagEnd - openIndex + 1);
                var close = text.IndexOf(tag, tagEnd + 1, StringComparison.Ordinal);

                return close == -1 ? text.Length : close + tag.Length;
            }
        }
        /// <summary>
        /// Sends a single statement and flushes the writes. The messages are
        /// Parse, Bind, Describe (portal), Execute and Sync. Bind carries every
        /// parameter of <paramref name="command"/>, each converted to text and
        /// encoded with the client encoding. No responses are read here.
        /// </summary>
        /// <param name="async">Passed through to <c>FlushWrites</c>.</param>
        /// <param name="queryText">The text of the single statement to send.</param>
        /// <param name="command">Supplies the parameter values.</param>
        /// <param name="cancellationToken">Passed through to <c>FlushWrites</c>.</param>
        private async Task QueryCore(
            bool async,
            string queryText,
            PostgresCommand command,
            CancellationToken cancellationToken)
        {
            BindParameter[] parameters     = null;
            var             parameterCount = command.Parameters.Count;

            // Number of leading entries of parameters that have been filled in.
            // Only those are cleaned up in the finally block, so a failure
            // part-way through never touches unpopulated entries.
            var populated = 0;

            try
            {
                if (parameterCount > 0)
                {
                    parameters = ArrayPool <BindParameter>
                                 .GetArray(parameterCount);

                    var encoding = ClientState.ClientEncoding;

                    for (var i = 0; i < parameterCount; ++i)
                    {
                        var param = command.Parameters[i].Value;

                        if (param == null || param is System.DBNull)
                        {
                            // DBNull.ToString() is "", so DBNull already went
                            // out as a zero-length value; it now skips the
                            // rented buffer.
                            // NOTE(review): the wire protocol encodes SQL NULL
                            // as length -1, whereas zero length is an empty
                            // string. Confirm how BindMessage writes this.
                            parameters[i] = new BindParameter {
                                ParameterByteCount = 0,
                                Parameters         = EmptyArray <byte> .Value
                            };
                        }
                        else
                        {
                            // TODO: This allocation fest is terrible. Make this
                            // write directly to the memorystream instead of having
                            // intermittent buffers for everything.

                            // Format with the invariant culture so the text does
                            // not depend on the client's locale (e.g. "1.5",
                            // never "1,5").
                            var paramString = System.Convert.ToString(
                                param, System.Globalization.CultureInfo.InvariantCulture);

                            var maxBytes = encoding
                                           .GetMaxByteCount(paramString.Length);

                            var paramBuffer = ArrayPool <byte> .GetArray(maxBytes);

                            var actualBytes = encoding.GetBytes(
                                paramString, 0, paramString.Length,
                                paramBuffer, 0);

                            parameters[i] = new BindParameter {
                                ParameterByteCount = actualBytes,
                                Parameters         = paramBuffer
                            };
                        }

                        populated = i + 1;
                    }
                }

                WriteMessage(new ParseMessage {
                    Query = queryText
                });

                WriteMessage(new BindMessage {
                    PreparedStatementName       = "",
                    ResultColumnFormatCodeCount = 1,
                    ParameterCount          = (short)parameterCount,
                    Parameters              = parameters,
                    ResultColumnFormatCodes =
                        QueryResultFormat == PostgresFormatCode.Binary
                            ? _binaryFormatCode
                            : _textFormatCode
                });

                WriteMessage(new DescribeMessage {
                    StatementTargetType = StatementTargetType.Portal
                });

                WriteMessage(new ExecuteMessage {
                });

                WriteMessage(new SyncMessage {
                });

                await FlushWrites(async, cancellationToken)
                .ConfigureAwait(false);
            }
            finally
            {
                for (var i = 0; i < populated; ++i)
                {
                    var buffer = parameters[i].Parameters;

                    // Only return buffers that were rented above. The shared
                    // empty array used for null values was never taken from
                    // the pool.
                    if (buffer == null ||
                        object.ReferenceEquals(buffer, EmptyArray <byte> .Value))
                    {
                        continue;
                    }

                    ArrayPool.Free(ref buffer);
                }

                // NOTE(review): the BindParameter[] obtained from the pool
                // above is never returned to it.
            }
        }