1  /* statement.c - the statement type
       2   *
       3   * Copyright (C) 2005-2010 Gerhard Häring <gh@ghaering.de>
       4   *
       5   * This file is part of pysqlite.
       6   *
       7   * This software is provided 'as-is', without any express or implied
       8   * warranty.  In no event will the authors be held liable for any damages
       9   * arising from the use of this software.
      10   *
      11   * Permission is granted to anyone to use this software for any purpose,
      12   * including commercial applications, and to alter it and redistribute it
      13   * freely, subject to the following restrictions:
      14   *
      15   * 1. The origin of this software must not be misrepresented; you must not
      16   *    claim that you wrote the original software. If you use this software
      17   *    in a product, an acknowledgment in the product documentation would be
      18   *    appreciated but is not required.
      19   * 2. Altered source versions must be plainly marked as such, and must not be
      20   *    misrepresented as being the original software.
      21   * 3. This notice may not be removed or altered from any source distribution.
      22   */
      23  
      24  #include "connection.h"
      25  #include "statement.h"
      26  #include "util.h"
      27  
      28  /* prototypes */
      29  static const char *lstrip_sql(const char *sql);
      30  
      31  pysqlite_Statement *
      32  pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
      33  {
      34      pysqlite_state *state = connection->state;
      35      assert(PyUnicode_Check(sql));
      36      Py_ssize_t size;
      37      const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size);
      38      if (sql_cstr == NULL) {
      39          return NULL;
      40      }
      41  
      42      sqlite3 *db = connection->db;
      43      int max_length = sqlite3_limit(db, SQLITE_LIMIT_SQL_LENGTH, -1);
      44      if (size > max_length) {
      45          PyErr_SetString(connection->DataError,
      46                          "query string is too large");
      47          return NULL;
      48      }
      49      if (strlen(sql_cstr) != (size_t)size) {
      50          PyErr_SetString(connection->ProgrammingError,
      51                          "the query contains a null character");
      52          return NULL;
      53      }
      54  
      55      sqlite3_stmt *stmt;
      56      const char *tail;
      57      int rc;
      58      Py_BEGIN_ALLOW_THREADS
      59      rc = sqlite3_prepare_v2(db, sql_cstr, (int)size + 1, &stmt, &tail);
      60      Py_END_ALLOW_THREADS
      61  
      62      if (rc != SQLITE_OK) {
      63          _pysqlite_seterror(state, db);
      64          return NULL;
      65      }
      66  
      67      if (lstrip_sql(tail) != NULL) {
      68          PyErr_SetString(connection->ProgrammingError,
      69                          "You can only execute one statement at a time.");
      70          goto error;
      71      }
      72  
      73      /* Determine if the statement is a DML statement.
      74         SELECT is the only exception. See #9924. */
      75      int is_dml = 0;
      76      const char *p = lstrip_sql(sql_cstr);
      77      if (p != NULL) {
      78          is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
      79                    || (PyOS_strnicmp(p, "update", 6) == 0)
      80                    || (PyOS_strnicmp(p, "delete", 6) == 0)
      81                    || (PyOS_strnicmp(p, "replace", 7) == 0);
      82      }
      83  
      84      pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
      85                                                 state->StatementType);
      86      if (self == NULL) {
      87          goto error;
      88      }
      89  
      90      self->st = stmt;
      91      self->is_dml = is_dml;
      92  
      93      PyObject_GC_Track(self);
      94      return self;
      95  
      96  error:
      97      (void)sqlite3_finalize(stmt);
      98      return NULL;
      99  }
     100  
     101  static void
     102  stmt_dealloc(pysqlite_Statement *self)
     103  {
     104      PyTypeObject *tp = Py_TYPE(self);
     105      PyObject_GC_UnTrack(self);
     106      if (self->st) {
     107          Py_BEGIN_ALLOW_THREADS
     108          sqlite3_finalize(self->st);
     109          Py_END_ALLOW_THREADS
     110          self->st = 0;
     111      }
     112      tp->tp_free(self);
     113      Py_DECREF(tp);
     114  }
     115  
     116  static int
     117  stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
     118  {
     119      Py_VISIT(Py_TYPE(self));
     120      return 0;
     121  }
     122  
     123  /*
     124   * Strip leading whitespace and comments from incoming SQL (null terminated C
     125   * string) and return a pointer to the first non-whitespace, non-comment
     126   * character.
     127   *
     128   * This is used to check if somebody tries to execute more than one SQL query
     129   * with one execute()/executemany() command, which the DB-API don't allow.
     130   *
     131   * It is also used to harden DML query detection.
     132   */
     133  static inline const char *
     134  lstrip_sql(const char *sql)
     135  {
     136      // This loop is borrowed from the SQLite source code.
     137      for (const char *pos = sql; *pos; pos++) {
     138          switch (*pos) {
     139              case ' ':
     140              case '\t':
     141              case '\f':
     142              case '\n':
     143              case '\r':
     144                  // Skip whitespace.
     145                  break;
     146              case '-':
     147                  // Skip line comments.
     148                  if (pos[1] == '-') {
     149                      pos += 2;
     150                      while (pos[0] && pos[0] != '\n') {
     151                          pos++;
     152                      }
     153                      if (pos[0] == '\0') {
     154                          return NULL;
     155                      }
     156                      continue;
     157                  }
     158                  return pos;
     159              case '/':
     160                  // Skip C style comments.
     161                  if (pos[1] == '*') {
     162                      pos += 2;
     163                      while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
     164                          pos++;
     165                      }
     166                      if (pos[0] == '\0') {
     167                          return NULL;
     168                      }
     169                      pos++;
     170                      continue;
     171                  }
     172                  return pos;
     173              default:
     174                  return pos;
     175          }
     176      }
     177  
     178      return NULL;
     179  }
     180  
     181  static PyType_Slot stmt_slots[] = {
     182      {Py_tp_dealloc, stmt_dealloc},
     183      {Py_tp_traverse, stmt_traverse},
     184      {0, NULL},
     185  };
     186  
     187  static PyType_Spec stmt_spec = {
     188      .name = MODULE_NAME ".Statement",
     189      .basicsize = sizeof(pysqlite_Statement),
     190      .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
     191                Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_DISALLOW_INSTANTIATION),
     192      .slots = stmt_slots,
     193  };
     194  
     195  int
     196  pysqlite_statement_setup_types(PyObject *module)
     197  {
     198      PyObject *type = PyType_FromModuleAndSpec(module, &stmt_spec, NULL);
     199      if (type == NULL) {
     200          return -1;
     201      }
     202      pysqlite_state *state = pysqlite_get_state(module);
     203      state->StatementType = (PyTypeObject *)type;
     204      return 0;
     205  }