diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst
index 2cab6ea1a..077e62493 100644
--- a/docs/peewee/api.rst
+++ b/docs/peewee/api.rst
@@ -805,7 +805,7 @@ Query Types
 .. py:class:: Query()
-    The parent class from which all other query classes are drived. While you
+    The parent class from which all other query classes are derived. While you
     will not deal with :py:class:`Query` directly in your code, it implements some
     methods that are common across all query types.
@@ -846,7 +846,8 @@ Query Types
     :param model: the model to join on. there must be a :py:class:`ForeignKeyField` between the current ``query context`` and the model passed in.
     :param join_type: allows the type of ``JOIN`` used to be specified explicitly,
-        one of ``JOIN.INNER``, ``JOIN.LEFT_OUTER``, ``JOIN.FULL``
+        one of ``JOIN.INNER``, ``JOIN.LEFT_OUTER``, ``JOIN.FULL``, ``JOIN.RIGHT_OUTER``,
+        or ``JOIN.CROSS``.
     :param on: if multiple foreign keys exist between two models, this parameter is the ForeignKeyField to join on.
     :rtype: a :py:class:`Query` instance
@@ -1154,6 +1155,22 @@ Query Types
        ``True`` then the database will raise an ``OperationalError`` if it cannot obtain the lock.
+    .. py:method:: with_lock([lock_type='UPDATE'])
+
+        :rtype: :py:class:`SelectQuery`
+
+        Indicates that this query should lock rows. A more generic version of
+        the :py:meth:`~SelectQuery.for_update` method.
+
+        Example:
+
+        .. code-block:: python
+
+            # SELECT * FROM some_model FOR KEY SHARE NOWAIT;
+            SomeModel.select().with_lock('KEY SHARE NOWAIT')
+
+        .. note:: You do not need to include the word *FOR*.
+
     .. py:method:: naive()
         :rtype: :py:class:`SelectQuery`
@@ -2130,10 +2147,34 @@ Database and its subclasses
        the :py:meth:`~Database.set_autocommit` and :py:meth:`Database.get_autocommit` methods.
-    .. py:method:: begin()
+    .. py:method:: begin([lock_type=None])
        Initiate a new transaction. By default **not** implemented as this is not
-        part of the DB-API 2.0, but provided for API compatibility.
+        part of the DB-API 2.0, but provided for API compatibility and to allow
+        SQLite users to specify the isolation level when beginning transactions.
+
+        For SQLite users, the valid isolation levels for ``lock_type`` are:
+
+        * ``exclusive``
+        * ``immediate``
+        * ``deferred``
+
+        Example usage:
+
+        .. code-block:: python
+
+            # Calling transaction() in turn calls begin('exclusive').
+            with db.transaction('exclusive'):
+                # No other readers or writers allowed while this is active.
+                (Account
+                 .update(balance=Account.balance - 100)
+                 .where(Account.id == from_acct)
+                 .execute())
+
+                (Account
+                 .update(balance=Account.balance + 100)
+                 .where(Account.id == to_acct)
+                 .execute())
     .. py:method:: commit()
@@ -2268,11 +2309,16 @@ Database and its subclasses
            db.drop_tables([User, Tweet, Something], safe=True)
-    .. py:method:: atomic()
+    .. py:method:: atomic([transaction_type=None])
        Execute statements in either a transaction or a savepoint. The outer-most call to *atomic* will use a transaction, and any subsequent nested calls will use savepoints.
+        :param str transaction_type: Specify the isolation level. This parameter
+            only has an effect on **SQLite databases**, and furthermore, only
+            affects the outer-most call to :py:meth:`~Database.atomic`. For
+            more information, see :py:meth:`~Database.transaction`.
+
        ``atomic`` can be used as either a context manager or a decorator.
        .. note::
@@ -2301,7 +2347,7 @@ Database and its subclasses
                # This function will execute in a transaction/savepoint.
                return User.create(username=username)
-    ..
py:method:: transaction()
+    .. py:method:: transaction([transaction_type=None])
        Execute statements in a transaction using either a context manager or decorator. If an error is raised inside the wrapped block, the transaction will be rolled
@@ -2311,6 +2357,8 @@ Database and its subclasses
        will keep a stack and only commit when it reaches the end of the outermost function / block.
+        :param str transaction_type: Specify the isolation level, **SQLite only**.
+
        Context manager example code:
        .. code-block:: python
@@ -2338,6 +2386,29 @@ Database and its subclasses
                to_acct.pay(amt)
                return amt
+        SQLite users can set the isolation level by passing one of the
+        following values for ``transaction_type``:
+
+        * ``exclusive``
+        * ``immediate``
+        * ``deferred``
+
+        Example usage:
+
+        .. code-block:: python
+
+            with db.transaction('exclusive'):
+                # No other readers or writers allowed while this is active.
+                (Account
+                 .update(balance=Account.balance - 100)
+                 .where(Account.id == from_acct)
+                 .execute())
+
+                (Account
+                 .update(balance=Account.balance + 100)
+                 .where(Account.id == to_acct)
+                 .execute())
+
    .. py:method:: commit_on_success(func)
        .. note:: Use :py:meth:`~Database.atomic` or :py:meth:`~Database.transaction` instead.
@@ -2651,12 +2722,14 @@ Misc
        ``fn.Stddev(Employee.salary).alias('sdv')`` ``Stddev(t1."salary") AS sdv``
        ============================================ ============================================
-    .. py:method:: over([partition_by=None[, order_by=None[, window=None]]])
+    .. py:method:: over([partition_by=None[, order_by=None[, start=None[, end=None[, window=None]]]]])
        Basic support for SQL window functions.
        :param list partition_by: List of :py:class:`Node` instances to partition by.
        :param list order_by: List of :py:class:`Node` instances to use for ordering.
+        :param start: The start of the *frame* of the window query.
+        :param end: The end of the *frame* of the window query.
        :param Window window: A :py:class:`Window` instance to use for this aggregate.
        Examples:
@@ -2701,6 +2774,15 @@ Misc
                     .window(window)  # Need to include our Window here.
                     .order_by(PageView.timestamp))
+            # Get the list of times along with the last time.
+            query = (Times
+                     .select(
+                         Times.time,
+                         fn.LAST_VALUE(Times.time).over(
+                             order_by=[Times.time],
+                             start=Window.preceding(),
+                             end=Window.following())))
+
.. py:class:: SQL(sql, *params)
    Add fragments of SQL to a peewee query. For example you might want to reference
    an aliased name:
    .. code-block:: python
        # Sort the users by number of tweets.
        query = query.order_by(SQL('ct DESC'))
-.. py:class:: Window([partition_by=None[, order_by=None]])
+.. py:class:: Window([partition_by=None[, order_by=None[, start=None[, end=None]]]])
    Create a ``WINDOW`` definition.
    :param list partition_by: List of :py:class:`Node` instances to partition by.
    :param list order_by: List of :py:class:`Node` instances to use for ordering.
+    :param start: The start of the *frame* of the window query.
+    :param end: The end of the *frame* of the window query.
    Examples:
    .. code-block:: python
@@ -2743,6 +2827,18 @@ Misc
                 .window(window)
                 .order_by(Employee.name))
+    .. py:staticmethod:: preceding([value=None])
+
+        Return an expression appropriate for passing in to the ``start`` or
+        ``end`` clause of a :py:class:`Window` object. If ``value`` is not
+        provided, then it will be ``UNBOUNDED PRECEDING``.
+
+    .. py:staticmethod:: following([value=None])
+
+        Return an expression appropriate for passing in to the ``start`` or
+        ``end`` clause of a :py:class:`Window` object.
If ``value`` is not
+        provided, then it will be ``UNBOUNDED FOLLOWING``.
+
.. py:class:: DeferredRelation()
    Used to reference a not-yet-created model class. Stands in as a placeholder for the related model of a foreign key. Useful for circular references.
diff --git a/docs/peewee/example.rst b/docs/peewee/example.rst
index b8a8d1c8f..696eae5f3 100644
--- a/docs/peewee/example.rst
+++ b/docs/peewee/example.rst
@@ -219,7 +219,7 @@ These methods are similar in their implementation but with an important differen
 Creating new objects
 ^^^^^^^^^^^^^^^^^^^^
-When a new user wants to join the site we need to make sure the username is available, and if so, create a new *User* record. Looking at the *join()* view, we can that our application attempts to create the User using :py:meth:`Model.create`. We defined the *User.username* field with a unique constraint, so if the username is taken the database will raise an ``IntegrityError``.
+When a new user wants to join the site we need to make sure the username is available, and if so, create a new *User* record. Looking at the *join()* view, we can see that our application attempts to create the User using :py:meth:`Model.create`. We defined the *User.username* field with a unique constraint, so if the username is taken the database will raise an ``IntegrityError``.
 .. code-block:: python
diff --git a/docs/peewee/hacks.rst b/docs/peewee/hacks.rst
index 39db585ac..e23e6ca24 100644
--- a/docs/peewee/hacks.rst
+++ b/docs/peewee/hacks.rst
@@ -5,6 +5,114 @@ Hacks
 Collected hacks using peewee. Have a cool hack you'd like to share? Open `an issue on GitHub `_ or `contact me `_.
+.. _optimistic_locking:
+
+Optimistic Locking
+------------------
+
+Optimistic locking is useful in situations where you might ordinarily use a
+*SELECT FOR UPDATE* (or in SQLite, *BEGIN IMMEDIATE*). For example, you might
+fetch a user record from the database, make some modifications, then save the
+modified user record. Typically this scenario would require us to lock the user
+record for the duration of the transaction, from the moment we select it, to
+the moment we save our changes.
+
+In optimistic locking, on the other hand, we do *not* acquire any lock and
+instead rely on an internal *version* column in the row we're modifying. At
+read time, we see what version the row is currently at, and on save, we ensure
+that the update takes place only if the version is the same as the one we
+initially read. If the version is higher, then some other process must have
+snuck in and changed the row -- saving our modified version could result in
+the loss of important changes.
+
+It's quite simple to implement optimistic locking in Peewee. Here is a base
+class that you can use as a starting point:
+
+.. code-block:: python
+
+    from peewee import *
+
+    class BaseVersionedModel(Model):
+        version = IntegerField(default=1, index=True)
+
+        def save_optimistic(self):
+            if not self.id:
+                # This is a new record, so the default logic is to perform an
+                # INSERT. Ideally your model would also have a unique
+                # constraint that made it impossible for two INSERTs to happen
+                # at the same time.
+                return self.save()
+
+            # Update any data that has changed and bump the version counter.
+            field_data = dict(self._data)
+            current_version = field_data.pop('version', 1)
+            field_data = self._prune_fields(field_data, self.dirty_fields)
+            if not field_data:
+                raise ValueError('No changes have been made.')
+
+            ModelClass = type(self)
+            field_data['version'] = ModelClass.version + 1  # Atomic increment.
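+            # ("ModelClass.version" is a Field object, so adding 1 builds a
+            # peewee expression that the database evaluates at UPDATE time,
+            # rendering as SET version = version + 1 -- hence the increment
+            # is atomic.)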
+ + query = ModelClass.update(**field_data).where( + (ModelClass.version == current_version) & + (ModelClass.id == self.id)) + if query.execute() == 0: + # No rows were updated, indicating another process has saved + # a new version. How you handle this situation is up to you, + # but for simplicity I'm just raising an exception. + raise ConflictDetectedException() + else: + # Increment local version to match what is now in the db. + self.version += 1 + return True + +Here's an example of how this works. Let's assume we have the following model +definition. Note that there's a unique constraint on the username -- this is +important as it provides a way to prevent double-inserts. + +.. code-block:: python + + class User(BaseVersionedModel): + username = CharField(unique=True) + favorite_animal = CharField() + +Example: + +.. code-block:: pycon + + >>> u = User(username='charlie', favorite_animal='cat') + >>> u.save_optimistic() + True + + >>> u.version + 1 + + >>> u.save_optimistic() + Traceback (most recent call last): + File "", line 1, in + File "x.py", line 18, in save_optimistic + raise ValueError('No changes have been made.') + ValueError: No changes have been made. + + >>> u.favorite_animal = 'kitten' + >>> u.save_optimistic() + True + + # Simulate a separate thread coming in and updating the model. + >>> u2 = User.get(User.username == 'charlie') + >>> u2.favorite_animal = 'macaw' + >>> u2.save_optimistic() + True + + # Now, attempt to change and re-save the original instance: + >>> u.favorite_animal = 'little parrot' + >>> u.save_optimistic() + Traceback (most recent call last): + File "", line 1, in + File "x.py", line 30, in save_optimistic + raise ConflictDetectedException() + ConflictDetectedException: current version is out of sync + .. _top_item_per_group: Top object per group diff --git a/docs/peewee/installation.rst b/docs/peewee/installation.rst index 775590fb6..e73df67b1 100644 --- a/docs/peewee/installation.rst +++ b/docs/peewee/installation.rst @@ -64,3 +64,34 @@ extension tests are not run. To view the available test runner options, use: .. code-block:: console python runtests.py --help + +Optional dependencies +--------------------- + +.. note:: + To use Peewee, you typically won't need anything outside the standard + library, since most Python distributions are compiled with SQLite support. + You can test by running ``import sqlite3`` in the Python console. If you + wish to use another database, there are many DB-API 2.0-compatible drivers + out there, such as ``pymysql`` or ``psycopg2`` for MySQL and Postgres + respectively. + +* `Cython `_: used for various speedups. Can give a big + boost to certain operations, particularly if you use SQLite. +* `apsw `_: an optional 3rd-party SQLite + binding offering greater performance and much, much saner semantics than the + standard library ``pysqlite``. Use with :py:class:`APSWDatabase`. +* `pycrypto `_ is used for the + :py:class:`AESEncryptedField`. +* ``bcrypt`` module is used for the :py:class:`PasswordField`. +* `vtfunc ` is used to provide some + table-valued functions for Sqlite as part of the ``sqlite_udf`` extensions + module. +* `gevent `_ is an optional dependency for + :py:class:`SqliteQueueDatabase` (though it works with ``threading`` just + fine). +* `BerkeleyDB `_ can + be compiled with a SQLite frontend, which works with Peewee. Compiling can be + tricky so `here are instructions `_. +* Lastly, if you use the *Flask* or *Django* frameworks, there are helper + extension modules available. 
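+
+As a quick sanity check, you can test which of the optional drivers are
+importable in your environment. A minimal sketch (the module names below are
+the import names of the packages mentioned above):
+
+.. code-block:: python
+
+    # Report which optional database drivers can be imported.
+    for name in ('sqlite3', 'psycopg2', 'pymysql', 'apsw'):
+        try:
+            __import__(name)
+            print('%s: available' % name)
+        except ImportError:
+            print('%s: not installed' % name)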
diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst
index d65f47fd8..b0fff57e0 100644
--- a/docs/peewee/playhouse.rst
+++ b/docs/peewee/playhouse.rst
@@ -3,7 +3,7 @@
 Playhouse, extensions to Peewee
 ===============================
-Peewee comes with numerous extrension modules which are collected under the ``playhouse`` namespace. Despite the silly name, there are some very useful extensions, particularly those that expose vendor-specific database features like the :ref:`sqlite_ext` and :ref:`postgres_ext` extensions.
+Peewee comes with numerous extension modules which are collected under the ``playhouse`` namespace. Despite the silly name, there are some very useful extensions, particularly those that expose vendor-specific database features like the :ref:`sqlite_ext` and :ref:`postgres_ext` extensions.
 Below you will find a loosely organized listing of the various modules that make up the ``playhouse``.
@@ -171,31 +171,6 @@ sqlite_ext API notes
        Unload the given SQLite extension.
-    .. py:method:: granular_transaction([lock_type='deferred'])
-
-        With the ``granular_transaction`` helper, you can specify the isolation level
-        for an individual transaction. The valid options are:
-
-        * ``exclusive``
-        * ``immediate``
-        * ``deferred``
-
-        Example usage:
-
-        .. code-block:: python
-
-            with db.granular_transaction('exclusive'):
-                # no other readers or writers!
-                (Account
-                 .update(Account.balance=Account.balance - 100)
-                 .where(Account.id == from_acct)
-                 .execute())
-
-                (Account
-                 .update(Account.balance=Account.balance + 100)
-                 .where(Account.id == to_acct)
-                 .execute())
-
 .. py:class:: VirtualModel
@@ -885,48 +860,66 @@ sqlite_ext API notes
    .. note:: For an in-depth discussion of the SQLite transitive closure extension, check out this blog post, `Querying Tree Structures in SQLite using Python and the Transitive Closure Extension `_.
+
 .. _sqliteq:
 SqliteQ
 -------
-The ``playhouse.sqliteq`` module provides a subclass of :py:class:`SqliteExtDatabase`,
-that will serialize concurrent access to a SQLite database. The :py:class:`SqliteQueueDatabase`
-is meant to be used as a drop-in replacement, and all the magic happens below
-the public APIs. This should hopefully make it very easy to integrate into an
-existing application.
+The ``playhouse.sqliteq`` module provides a subclass of
+:py:class:`SqliteExtDatabase` that will serialize concurrent writes to a
+SQLite database. :py:class:`SqliteQueueDatabase` can be used as a drop-in
+replacement for the regular :py:class:`SqliteDatabase` if you want simple
+**read and write** access to a SQLite database from **multiple threads**.
+
+SQLite only allows one connection to write to the database at any given time.
+As a result, if you have a multi-threaded application (like a web-server, for
+example) that needs to write to the database, you may see occasional errors
+when one or more of the threads attempting to write cannot acquire the lock.
+
+:py:class:`SqliteQueueDatabase` is designed to simplify things by sending all
+write queries through a single, long-lived connection. The benefit is that you
+get the appearance of multiple threads writing to the database without
+conflicts or timeouts. The downside, however, is that you cannot issue
+write transactions that encompass multiple queries -- all writes run in
+autocommit mode, essentially.
 .. note::
-    This is a new module and should be considered alpha-quality software.
+    The module gets its name from the fact that all write queries get put into
+    a thread-safe queue.
A single worker thread listens to the queue and + executes all queries that are sent to it. -Explanation -^^^^^^^^^^^ +Transactions +^^^^^^^^^^^^ + +Because all queries are serialized and executed by a single worker thread, it +is possible for transactional SQL from separate threads to be executed +out-of-order. In the example below, the transaction started by thread "B" is +rolled back by thread "A" (with bad consequences!): + +* Thread A: UPDATE transplants SET organ='liver', ...; +* Thread B: BEGIN TRANSACTION; +* Thread B: UPDATE life_support_system SET timer += 60 ...; +* Thread A: ROLLBACK; -- Oh no.... + +Since there is a potential for queries from separate transactions to be +interleaved, the :py:meth:`~SqliteQueueDatabase.transaction` and +:py:meth:`~SqliteQueueDatabase.atomic` methods are disabled on :py:class:`SqliteQueueDatabase`. -It is important to understand the way SQLite handles concurrency when using -`write-ahead logging `_, but in the simpleset -terms only one connection can write to the database at a time **and** any -number of other connections can read while the database is being written to. -Or, in other words, readers don't block the writer, and the writer doesn't -block the readers. - -An example that comes to my mind is a web application, which handles each -request in a separate thread/greenlet. If the application is particularly busy -and there are multiple connections open at a given point in time, you can end -up in a bad situation quickly because SQLite limits you to one writer. This -typically manifests as ``OperationalError: database is locked`` exceptions. - -Due to the global interpreter lock, however, Python appears single-threaded to -other applications (only one thread can run Python code at a time, per -interpreter process). It follows then, that even though multiple threads are -attempting to access the SQLite database, SQLite only sees one thread accessing -the database at any point in time. - -So, what we can do is create a single *worker* thread that is responsible for -all writes to the database, and have our other request-handling threads -hand-off their writes. In this way, we'll have our cake and eat it, too -- our -Python application can queue-up writes from as many threads as it wants and we -should hardly notice the performance hit that comes from pushing all database -accesses through a single thread. +For cases when you wish to temporarily write to the database from a different +thread, you can use the :py:meth:`~SqliteQueueDatabase.pause` and +:py:meth:`~SqliteQueueDatabase.unpause` methods. These methods block the +caller until the writer thread is finished with its current workload. The +writer then disconnects and the caller takes over until ``unpause`` is called. + +The :py:meth:`~SqliteQueueDatabase.stop`, :py:meth:`~SqliteQueueDatabase.start`, +and :py:meth:`~SqliteQueueDatabase.is_stopped` methods can also be used to +control the writer thread. + +.. note:: + Take a look at SQLite's `isolation `_ + documentation for more information about how SQLite handles concurrent + connections. Code sample ^^^^^^^^^^^ @@ -944,19 +937,15 @@ creation, and locking. db = SqliteQueueDatabase( 'my_app.db', - use_gevent=False, # Use standard library "threading" module. - autostart=False, # Do not automatically start the workers. - queue_max_size=1024, # Max. # of pending writes that can accumulate. - readers=4, # Size of reader thread-pool - these handle non-writes. + use_gevent=False, # Use the standard library "threading" module. 
+        autostart=False,  # The worker thread now must be started manually.
+        queue_max_size=64,  # Max. # of pending writes that can accumulate.
         results_timeout=5.0)  # Max. time to wait for query to be executed.
 If ``autostart=False``, as in the above example, you will need to call
 :py:meth:`~SqliteQueueDatabase.start` to bring up the worker threads that will
-do the actual query execution. Additionally, because the connections are
-managed by the database class itself, you do not need to call
-:py:meth:`~Database.connect` or :py:meth:`~Database.close` at any point in your
-application.
+do the actual write query execution.
 .. code-block:: python
    @app.before_first_request
    def _start_worker_threads():
        db.start()
-When your application is ready to terminate, use the
-:py:meth:`~SqliteQueueDatabase.stop` method to shut down the worker threads.
-If there was a backlog of work, then this method will block until all pending
-work is finished (though no new work is allowed).
+If you plan on performing SELECT queries or otherwise need to access the
+database, you will need to call :py:meth:`~Database.connect` and
+:py:meth:`~Database.close` as you would with any other database instance.
+
+When your application is ready to terminate, use the :py:meth:`~SqliteQueueDatabase.stop`
+method to shut down the worker thread. If there was a backlog of work, then
+this method will block until all pending work is finished (though no new work
+is allowed).
 .. code-block:: python
@@ -979,8 +972,7 @@ work is finished (though no new work is allowed).
 Lastly, the :py:meth:`~SqliteQueueDatabase.is_stopped` method can be used to
-determine whether the database workers are up and running.
-
+determine whether the database writer is up and running.
 .. _sqlite_udf:
@@ -1100,13 +1092,6 @@ apsw_ext API notes
    :param string database: filename of sqlite database
    :param connect_kwargs: keyword arguments passed to apsw when opening a connection
-    .. py:method:: transaction([lock_type='deferred'])
-
-        Functions just like the :py:meth:`Database.transaction` context manager,
-        but accepts an additional parameter specifying the type of lock to use.
-
-        :param string lock_type: type of lock to use when opening a new transaction
    .. py:method:: register_module(mod_name, mod_inst)
        Provides a way of globally registering a module. For more information,
@@ -3983,13 +3968,14 @@ That's it! If you would like finer-grained control over the pool of connections,
 Pool APIs
 ^^^^^^^^^
-.. py:class:: PooledDatabase(database[, max_connections=20[, stale_timeout=None[, **kwargs]]])
+.. py:class:: PooledDatabase(database[, max_connections=20[, stale_timeout=None[, timeout=None[, **kwargs]]]])
    Mixin class intended to be used with a subclass of :py:class:`Database`.
    :param str database: The name of the database or database file.
    :param int max_connections: Maximum number of connections. Provide ``None`` for unlimited.
    :param int stale_timeout: Number of seconds to allow connections to be used.
+    :param int timeout: Number of seconds to block when the pool is full. By default peewee does not block when the pool is full but simply throws an exception. To block indefinitely, set this value to ``0``.
    :param kwargs: Arbitrary keyword arguments passed to database class.
    .. note:: Connections will not be closed exactly when they exceed their `stale_timeout`. Instead, stale connections are only closed when a new connection is requested.
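+
+    For example, here is a sketch of configuring one of the pooled subclasses,
+    :py:class:`PooledPostgresqlDatabase`, with the blocking behavior described
+    above (the database name and credentials are hypothetical):
+
+    .. code-block:: python
+
+        from playhouse.pool import PooledPostgresqlDatabase
+
+        db = PooledPostgresqlDatabase(
+            'my_app',           # Hypothetical database name.
+            max_connections=8,  # Allow at most 8 open connections.
+            stale_timeout=300,  # Recycle connections used for > 5 minutes.
+            timeout=0,          # Block indefinitely when the pool is full.
+            user='postgres')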
@@ -4254,10 +4240,18 @@ Flask Utils
 The ``playhouse.flask_utils`` module contains several helpers for integrating peewee with the `Flask `_ web framework.
-Database wrapper
+Database Wrapper
 ^^^^^^^^^^^^^^^^
-The :py:class:`FlaskDB` class provides a convenient way to configure a peewee :py:class:`Database` instance using Flask app configuration. The :py:class:`FlaskDB` wrapper will also automatically set up request setup and teardown handlers to ensure your connections are managed correctly.
+The :py:class:`FlaskDB` class is a wrapper for configuring and referencing a
+Peewee database from within a Flask application. Don't let its name fool you:
+it is **not the same thing as a peewee database**. ``FlaskDB`` is designed to
+remove the following boilerplate from your Flask app:
+
+* Dynamically create a Peewee database instance based on app config data.
+* Create a base class from which all your application's models will descend.
+* Register hooks at the start and end of a request to handle opening and
+  closing a database connection.
 Basic usage:
@@ -4273,28 +4267,43 @@ Basic usage:
    app = Flask(__name__)
    app.config.from_object(__name__)
-    database = FlaskDB(app)
+    db_wrapper = FlaskDB(app)
-    class User(database.Model):
+    class User(db_wrapper.Model):
        username = CharField(unique=True)
-    class Tweet(database.Model):
+    class Tweet(db_wrapper.Model):
        user = ForeignKeyField(User, related_name='tweets')
        content = TextField()
        timestamp = DateTimeField(default=datetime.datetime.now)
 The above code example will create and instantiate a peewee :py:class:`PostgresqlDatabase` specified by the given database URL. Request hooks will be configured to establish a connection when a request is received, and automatically close the connection when the response is sent. Lastly, the :py:class:`FlaskDB` class exposes a :py:attr:`FlaskDB.Model` property which can be used as a base for your application's models.
-.. note:: The underlying peewee database can be accessed using the ``FlaskDB.database`` attribute.
+Here is how you can access the wrapped Peewee database instance that is
+configured for you by the ``FlaskDB`` wrapper:
+
+.. code-block:: python
+
+    # Obtain a reference to the Peewee database instance.
+    peewee_db = db_wrapper.database
-If you prefer, you can also pass the database value directly into the ``FlaskDB`` object:
+    @app.route('/transfer-funds/', methods=['POST'])
+    def transfer_funds():
+        with peewee_db.atomic():
+            # ...
+
+        return jsonify({'transfer-id': xid})
+
+.. note:: The actual peewee database can be accessed using the ``FlaskDB.database`` attribute.
+
+Here is another way to configure a Peewee database using ``FlaskDB``:
 .. code-block:: python
    app = Flask(__name__)
-    database = FlaskDB(app, 'sqlite:///my_app.db')
+    db_wrapper = FlaskDB(app, 'sqlite:///my_app.db')
-While the above examples show using a database URL, for more advanced usages you can specify a dictionary of configuration options or simply pass in a peewee :py:class:`Database` instance:
+While the above examples show using a database URL, for more advanced usages you can specify a dictionary of configuration options, or simply pass in a peewee :py:class:`Database` instance:
 ..
code-block:: python
@@ -4309,7 +4318,8 @@ While the above examples show using a database URL, for more advanced usages you
    app = Flask(__name__)
    app.config.from_object(__name__)
-    database = FlaskDB(app)
+    wrapper = FlaskDB(app)
+    pooled_postgres_db = wrapper.database
 Using a peewee :py:class:`Database` object:
@@ -4317,7 +4327,8 @@ Using a peewee :py:class:`Database` object:
    peewee_db = PostgresqlExtDatabase('my_app')
    app = Flask(__name__)
-    flask_db = FlaskDB(app, peewee_db)
+    db_wrapper = FlaskDB(app, peewee_db)
+
 Database with Application Factory
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Using as a factory:
 .. code-block:: python
-    database = FlaskDB()
+    db_wrapper = FlaskDB()
    # Even though the database is not yet initialized, you can still use the
    # `Model` property to create model classes.
-    class User(database.Model):
+    class User(db_wrapper.Model):
        username = CharField(unique=True)
    def create_app():
        app = Flask(__name__)
        app.config['DATABASE'] = 'sqlite:////home/code/apps/my-database.db'
-        database.init_app(app)
+        db_wrapper.init_app(app)
        return app
 Query utilities
diff --git a/docs/peewee/querying.rst b/docs/peewee/querying.rst
index ec590b9c0..cd59c5562 100644
--- a/docs/peewee/querying.rst
+++ b/docs/peewee/querying.rst
@@ -1180,7 +1180,7 @@ Use the :py:meth:`~Query.join` method to *JOIN* additional tables. When a foreig
 >>> my_tweets = Tweet.select().join(User).where(User.username == 'charlie')
-By default peewee will use an *INNER* join, but you can use *LEFT OUTER* or *FULL* joins as well:
+By default peewee will use an *INNER* join, but you can use *LEFT OUTER*, *RIGHT OUTER*, *FULL*, or *CROSS* joins as well:
 .. code-block:: python
diff --git a/docs/peewee/quickstart.rst b/docs/peewee/quickstart.rst
index 7e87b4dce..5c21a587a 100644
--- a/docs/peewee/quickstart.rst
+++ b/docs/peewee/quickstart.rst
@@ -31,7 +31,7 @@
 Field instance Column on a table
 Model instance Row in a database table
 ================= =================================
-When starting to a project with peewee, it's typically best to begin with your data model, by defining one or more :py:class:`Model` classes:
+When starting a project with peewee, it's typically best to begin with your data model, by defining one or more :py:class:`Model` classes:
 .. code-block:: python
diff --git a/docs/peewee/transactions.rst b/docs/peewee/transactions.rst
index 66380edfe..e37888eb5 100644
--- a/docs/peewee/transactions.rst
+++ b/docs/peewee/transactions.rst
@@ -7,7 +7,53 @@
 Peewee provides several interfaces for working with transactions. The most gener
 If an exception occurs in a wrapped block, the current transaction/savepoint will be rolled back. Otherwise the statements will be committed at the end of the wrapped block.
-:py:meth:`~Database.atomic` can be used as either a **context manager** or a **decorator**.
+.. note::
+    While inside a block wrapped by the :py:meth:`~Database.atomic` context
+    manager, you can explicitly roll back or commit at any point by calling
+    :py:meth:`Transaction.rollback` or :py:meth:`Transaction.commit`. When you
+    do this inside a wrapped block of code, a new transaction will be started
+    automatically.
+
+    Consider this code:
+
+    .. code-block:: python
+
+        db.begin()  # Open a new transaction.
+        error_saving = False
+        try:
+            save_some_objects()
+        except ErrorSavingData:
+            db.rollback()  # Uh-oh! Let's roll-back any partial changes.
+            error_saving = True
+
+        create_report(error_saving=error_saving)
+        db.commit()  # What happens here??
+
+    If the ``ErrorSavingData`` exception gets raised, we call rollback, but
+    because we are not using the :py:meth:`~Database.atomic` context manager, **no new
+    transaction is begun**. The call to ``commit()`` will fail because no
+    transaction is active!
+
+    On the other hand, consider this:
+
+    .. code-block:: python
+
+        with db.atomic() as transaction:  # Opens new transaction.
+            error_saving = False
+            try:
+                save_some_objects()
+            except ErrorSavingData:
+                # Because this block of code is wrapped with "atomic", a
+                # new transaction will begin automatically after the call
+                # to rollback().
+                db.rollback()
+                error_saving = True
+
+            create_report(error_saving=error_saving)
+            # Note: no need to call commit. Since this marks the end of the
+            # wrapped block of code, the `atomic` context manager will
+            # automatically call commit for us.
+
+.. note::
+    :py:meth:`~Database.atomic` can be used as either a **context manager** or a **decorator**.
 Context manager
 ---------------
diff --git a/examples/blog/app.py b/examples/blog/app.py
index 5dc747309..f3a347c6e 100644
--- a/examples/blog/app.py
+++ b/examples/blog/app.py
@@ -93,8 +93,11 @@ def update_search_index(self):
        # Create a row in the FTSEntry table with the post content. This will
        # allow us to use SQLite's awesome full-text search extension to
        # search our entries.
+        query = (FTSEntry
+                 .select(FTSEntry.docid, FTSEntry.entry_id)
+                 .where(FTSEntry.entry_id == self.id))
        try:
-            fts_entry = FTSEntry.get(FTSEntry.entry_id == self.id)
+            fts_entry = query.get()
        except FTSEntry.DoesNotExist:
            fts_entry = FTSEntry(entry_id=self.id)
            force_insert = True
@@ -190,23 +193,34 @@ def index():
        search=search_query,
        check_bounds=False)
+def _create_or_edit(entry, template):
+    if request.method == 'POST':
+        entry.title = request.form.get('title') or ''
+        entry.content = request.form.get('content') or ''
+        entry.published = request.form.get('published') or False
+        if not (entry.title and entry.content):
+            flash('Title and Content are required.', 'danger')
+        else:
+            # Wrap the call to save in a transaction so we can roll it back
+            # cleanly in the event of an integrity error.
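+            # (If the IntegrityError escapes the wrapped block, atomic()
+            # rolls the transaction back automatically before we catch it.)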
+ try: + with database.atomic(): + entry.save() + except IntegrityError: + flash('Error: this title is already in use.', 'danger') + else: + flash('Entry saved successfully.', 'success') + if entry.published: + return redirect(url_for('detail', slug=entry.slug)) + else: + return redirect(url_for('edit', slug=entry.slug)) + + return render_template(template, entry=entry) + @app.route('/create/', methods=['GET', 'POST']) @login_required def create(): - if request.method == 'POST': - if request.form.get('title') and request.form.get('content'): - entry = Entry.create( - title=request.form['title'], - content=request.form['content'], - published=request.form.get('published') or False) - flash('Entry created successfully.', 'success') - if entry.published: - return redirect(url_for('detail', slug=entry.slug)) - else: - return redirect(url_for('edit', slug=entry.slug)) - else: - flash('Title and Content are required.', 'danger') - return render_template('create.html') + return _create_or_edit(Entry(title='', content=''), 'create.html') @app.route('/drafts/') @login_required @@ -227,22 +241,7 @@ def detail(slug): @login_required def edit(slug): entry = get_object_or_404(Entry, Entry.slug == slug) - if request.method == 'POST': - if request.form.get('title') and request.form.get('content'): - entry.title = request.form['title'] - entry.content = request.form['content'] - entry.published = request.form.get('published') or False - entry.save() - - flash('Entry saved successfully.', 'success') - if entry.published: - return redirect(url_for('detail', slug=entry.slug)) - else: - return redirect(url_for('edit', slug=entry.slug)) - else: - flash('Title and Content are required.', 'danger') - - return render_template('edit.html', entry=entry) + return _create_or_edit(entry, 'edit.html') @app.template_filter('clean_querystring') def clean_querystring(request_args, *keys_to_remove, **new_values): diff --git a/examples/blog/templates/create.html b/examples/blog/templates/create.html index 5e58779dc..bd7081ca6 100644 --- a/examples/blog/templates/create.html +++ b/examples/blog/templates/create.html @@ -5,31 +5,31 @@ {% block content_title %}Create entry{% endblock %} {% block content %} -
+    [entry form markup revised: the form's action URL and the submit button
+    are now emitted through overridable ``form_action`` and ``save_button``
+    blocks, which edit.html overrides below; the Cancel link is unchanged]
diff --git a/examples/blog/templates/edit.html b/examples/blog/templates/edit.html
index 5ff394fdc..1ffa59bba 100644
--- a/examples/blog/templates/edit.html
+++ b/examples/blog/templates/edit.html
@@ -1,37 +1,9 @@
-{% extends "base.html" %}
+{% extends "create.html" %}
 {% block title %}Edit entry{% endblock %}
 {% block content_title %}Edit entry{% endblock %}
-{% block content %}
-    [removed: the inline entry form markup, which edit.html now inherits from create.html]
-{% endblock %} +{% block form_action %}{{ url_for('edit', slug=entry.slug) }}{% endblock %} + +{% block save_button %}Save changes{% endblock %} diff --git a/peewee.py b/peewee.py index a3db8a3b0..941b33f51 100644 --- a/peewee.py +++ b/peewee.py @@ -41,7 +41,7 @@ from functools import wraps from inspect import isclass -__version__ = '2.8.5' +__version__ = '2.8.6' __all__ = [ 'BareField', 'BigIntegerField', @@ -91,6 +91,7 @@ 'TextField', 'TimeField', 'TimestampField', + 'Tuple', 'Using', 'UUIDField', 'Window', @@ -145,8 +146,8 @@ def print_(s): raise RuntimeError('Unsupported python version.') if PY26: - _M = 10**6 - total_seconds = lambda t: (t.microseconds + 0.0 + (t.seconds + t.days * 24 * 3600) * _M) / _M + _D, _M = 24 * 3600., 10**6 + total_seconds = lambda t: (t.microseconds+(t.seconds+t.days*_D)*_M)/_M else: total_seconds = lambda t: t.total_seconds() @@ -206,6 +207,9 @@ def dfs(model): for foreign_key in model._meta.reverse_rel.values(): dfs(foreign_key.model_class) ordering.append(model) # parent will follow descendants + if model._meta.depends_on: + for dependency in model._meta.depends_on: + dfs(dependency) # Order models by name and table initially to guarantee total ordering. names = lambda m: (m._meta.name, m._meta.db_table) for m in sorted(models, key=names, reverse=True): @@ -305,6 +309,8 @@ class attrdict(dict): def __getattr__(self, attr): return self[attr] +SENTINEL = object() + # Operators used in binary expressions. OP = attrdict( AND='and', @@ -339,6 +345,7 @@ def __getattr__(self, attr): LEFT_OUTER='LEFT OUTER', RIGHT_OUTER='RIGHT OUTER', FULL='FULL', + CROSS='CROSS', ) JOIN_INNER = JOIN.INNER JOIN_LEFT_OUTER = JOIN.LEFT_OUTER @@ -400,7 +407,7 @@ class Proxy(object): Proxy class useful for situations when you wish to defer the initialization of an object. 
""" - __slots__ = ['obj', '_callbacks'] + __slots__ = ('obj', '_callbacks') def __init__(self): self._callbacks = [] @@ -638,12 +645,18 @@ def clone_base(self): res._coerce = self._coerce return res - def over(self, partition_by=None, order_by=None, window=None): + def over(self, partition_by=None, order_by=None, start=None, end=None, + window=None): if isinstance(partition_by, Window) and window is None: window = partition_by + if start is not None and not isinstance(start, SQL): + start = SQL(*start) + if end is not None and not isinstance(end, SQL): + end = SQL(*end) + if window is None: - sql = Window( - partition_by=partition_by, order_by=order_by).__sql__() + sql = Window(partition_by=partition_by, order_by=order_by, + start=start, end=end).__sql__() else: sql = SQL(window._alias) return Clause(self, SQL('OVER'), sql) @@ -718,14 +731,33 @@ class CommaClause(Clause): class EnclosedClause(CommaClause): """One or more Node objects joined by commas and enclosed in parens.""" parens = True +Tuple = EnclosedClause class Window(Node): - def __init__(self, partition_by=None, order_by=None): + CURRENT_ROW = 'CURRENT ROW' + + def __init__(self, partition_by=None, order_by=None, start=None, end=None): super(Window, self).__init__() self.partition_by = partition_by self.order_by = order_by + self.start = start + self.end = end + if self.start is None and self.end is not None: + raise ValueError('Cannot specify WINDOW end without start.') self._alias = self._alias or 'w' + @staticmethod + def following(value=None): + if value is None: + return SQL('UNBOUNDED FOLLOWING') + return SQL('%d FOLLOWING' % value) + + @staticmethod + def preceding(value=None): + if value is None: + return SQL('UNBOUNDED PRECEDING') + return SQL('%d PRECEDING' % value) + def __sql__(self): over_clauses = [] if self.partition_by: @@ -736,6 +768,14 @@ def __sql__(self): over_clauses.append(Clause( SQL('ORDER BY'), CommaClause(*self.order_by))) + if self.start is not None and self.end is not None: + over_clauses.append(Clause( + SQL('RANGE BETWEEN'), + self.start, + SQL('AND'), + self.end)) + elif self.start is not None: + over_clauses.append(Clause(SQL('RANGE'), self.start)) return EnclosedClause(Clause(*over_clauses)) def clone_base(self): @@ -877,7 +917,7 @@ class Field(Node): def __init__(self, null=False, index=False, unique=False, verbose_name=None, help_text=None, db_column=None, default=None, choices=None, primary_key=False, sequence=None, - constraints=None, schema=None): + constraints=None, schema=None, undeclared=False): self.null = null self.index = index self.unique = unique @@ -890,6 +930,7 @@ def __init__(self, null=False, index=False, unique=False, self.sequence = sequence # Name of sequence, e.g. foo_id_seq. self.constraints = constraints # List of column constraints. self.schema = schema # Name of schema, e.g. 'public'. + self.undeclared = undeclared # Whether this field is part of schema. # Used internally for recovering the order in which Fields were defined # on the Model class. @@ -914,6 +955,7 @@ def clone_base(self, **kwargs): sequence=self.sequence, constraints=self.constraints, schema=self.schema, + undeclared=self.undeclared, **kwargs) if self._is_bound: inst.name = self.name @@ -1021,6 +1063,12 @@ def __init__(self, *args, **kwargs): class _AutoPrimaryKeyField(PrimaryKeyField): _column_name = None + def __init__(self, *args, **kwargs): + if 'undeclared' in kwargs and not kwargs['undeclared']: + raise ValueError('%r must be created with undeclared=True.' 
% self) + kwargs['undeclared'] = True + super(_AutoPrimaryKeyField, self).__init__(*args, **kwargs) + def add_to_class(self, model_class, name): if name != self._column_name: raise ValueError('%s must be named `%s`.' % (type(self), name)) @@ -1042,6 +1090,7 @@ def __init__(self, max_digits=10, decimal_places=5, auto_round=False, self.decimal_places = decimal_places self.auto_round = auto_round self.rounding = rounding or decimal.DefaultContext.rounding + self._exp = decimal.Decimal(10) ** (-self.decimal_places) super(DecimalField, self).__init__(*args, **kwargs) def clone_base(self, **kwargs): @@ -1059,10 +1108,10 @@ def db_value(self, value): D = decimal.Decimal if not value: return value if value is None else D(0) - if self.auto_round: - exp = D(10) ** (-self.decimal_places) - rounding = self.rounding - return D(str(value)).quantize(exp, rounding=rounding) + elif self.auto_round or not isinstance(value, D): + value = D(str(value)) + if value.is_normal() and self.auto_round: + value = value.quantize(self._exp, rounding=self.rounding) return value def python_value(self, value): @@ -1075,7 +1124,10 @@ def coerce_to_unicode(s, encoding='utf-8'): if isinstance(s, unicode_type): return s elif isinstance(s, string_type): - return s.decode(encoding) + try: + return s.decode(encoding) + except UnicodeDecodeError: + return s return unicode_type(s) class CharField(Field): @@ -1602,6 +1654,7 @@ class QueryCompiler(object): JOIN.LEFT_OUTER: 'LEFT OUTER JOIN', JOIN.RIGHT_OUTER: 'RIGHT OUTER JOIN', JOIN.FULL: 'FULL JOIN', + JOIN.CROSS: 'CROSS JOIN', } alias_map_class = AliasMap @@ -1701,34 +1754,31 @@ def _parse_field(self, node, alias_map, conv): def _parse_compound_select_query(self, node, alias_map, conv): csq = 'compound_select_query' - if node.rhs._node_type == csq and node.lhs._node_type != csq: - first_q, second_q = node.rhs, node.lhs - inv = True - else: - first_q, second_q = node.lhs, node.rhs - inv = False + lhs, rhs = node.lhs, node.rhs + inv = rhs._node_type == csq and lhs._node_type != csq + if inv: + lhs, rhs = rhs, lhs new_map = self.alias_map_class() - if first_q._node_type == csq: + if lhs._node_type == csq: new_map._counter = alias_map._counter - first, first_p = self.generate_select(first_q, new_map) - second, second_p = self.generate_select( - second_q, - self.calculate_alias_map(second_q, new_map)) - - if inv: - l, lp, r, rp = second, second_p, first, first_p - else: - l, lp, r, rp = first, first_p , second, second_p + sql1, p1 = self.generate_select(lhs, new_map) + sql2, p2 = self.generate_select(rhs, self.calculate_alias_map(rhs, + new_map)) # We add outer parentheses in the event the compound query is used in # the `from_()` clause, in which case we'll need them. if node.database.compound_select_parentheses: - sql = '((%s) %s (%s))' % (l, node.operator, r) - else: - sql = '(%s %s %s)' % (l, node.operator, r) - return sql, lp + rp + if lhs._node_type != csq: + sql1 = '(%s)' % sql1 + if rhs._node_type != csq: + sql2 = '(%s)' % sql2 + + if inv: + sql1, p1, sql2, p2 = sql2, p2, sql1, p1 + + return '(%s %s %s)' % (sql1, node.operator, sql2), (p1 + p2) def _parse_select_query(self, node, alias_map, conv): clone = node.clone() @@ -1848,10 +1898,11 @@ def generate_joins(self, joins, model_class, alias_map): for join in joins[curr]: src = curr dest = join.dest + join_type = join.get_join_type() if isinstance(join.on, (Expression, Func, Clause, Entity)): # Clear any alias on the join expression. 
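                # (alias() with no argument resets the alias, preventing the
                # ON clause from rendering with a spurious "AS x".)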
constraint = join.on.clone().alias() - else: + elif join_type != JOIN.CROSS: metadata = join.metadata if metadata.is_backref: fk_model = join.dest @@ -1877,13 +1928,12 @@ def generate_joins(self, joins, model_class, alias_map): q.append(dest) dest_n = dest.as_entity().alias(alias_map[dest]) - join_type = join.get_join_type() - if join_type in self.join_map: - join_sql = SQL(self.join_map[join_type]) + join_sql = SQL(self.join_map.get(join_type) or join_type) + if join_type == JOIN.CROSS: + clauses.append(Clause(join_sql, dest_n)) else: - join_sql = SQL(join_type) - clauses.append( - Clause(join_sql, dest_n, SQL('ON'), constraint)) + clauses.append(Clause(join_sql, dest_n, SQL('ON'), + constraint)) return clauses @@ -1912,15 +1962,6 @@ def generate_select(self, query, alias_map=None): else: clauses.append(CommaClause(*query._from)) - if query._windows is not None: - clauses.append(SQL('WINDOW')) - clauses.append(CommaClause(*[ - Clause( - SQL(window._alias), - SQL('AS'), - window.__sql__()) - for window in query._windows])) - join_clauses = self.generate_joins(query._joins, model, alias_map) if join_clauses: clauses.extend(join_clauses) @@ -1934,19 +1975,26 @@ def generate_select(self, query, alias_map=None): if query._having: clauses.extend([SQL('HAVING'), query._having]) + if query._windows is not None: + clauses.append(SQL('WINDOW')) + clauses.append(CommaClause(*[ + Clause( + SQL(window._alias), + SQL('AS'), + window.__sql__()) + for window in query._windows])) + if query._order_by: clauses.extend([SQL('ORDER BY'), CommaClause(*query._order_by)]) if query._limit is not None or (query._offset and db.limit_max): limit = query._limit if query._limit is not None else db.limit_max - clauses.append(SQL('LIMIT %s' % limit)) + clauses.append(SQL('LIMIT %d' % limit)) if query._offset is not None: - clauses.append(SQL('OFFSET %s' % query._offset)) + clauses.append(SQL('OFFSET %d' % query._offset)) - for_update, no_wait = query._for_update - if for_update: - stmt = 'FOR UPDATE NOWAIT' if no_wait else 'FOR UPDATE' - clauses.append(SQL(stmt)) + if query._for_update: + clauses.append(SQL(query._for_update)) return self.build_query(clauses, alias_map) @@ -2738,12 +2786,14 @@ def orwhere(self, *expressions): @returns_clone def join(self, dest, join_type=None, on=None): src = self._query_ctx - if not on: - require_join_condition = ( + if on is None: + require_join_condition = join_type != JOIN.CROSS and ( isinstance(dest, SelectQuery) or (isclass(dest) and not src._meta.rel_exists(dest))) if require_join_condition: raise ValueError('A join condition must be specified.') + elif join_type == JOIN.CROSS: + raise ValueError('A CROSS join cannot have a constraint.') elif isinstance(on, basestring): on = src._meta.fields[on] self._joins.setdefault(src, []) @@ -2917,7 +2967,7 @@ def __init__(self, model_class, *selection): self._limit = None self._offset = None self._distinct = False - self._for_update = (False, False) + self._for_update = None self._naive = False self._tuples = False self._dicts = False @@ -3024,7 +3074,12 @@ def distinct(self, is_distinct=True): @returns_clone def for_update(self, for_update=True, nowait=False): - self._for_update = (for_update, nowait) + self._for_update = 'FOR UPDATE NOWAIT' if for_update and nowait else \ + 'FOR UPDATE' if for_update else None + + @returns_clone + def with_lock(self, lock_type='UPDATE'): + self._for_update = ('FOR %s' % lock_type) if lock_type else None @returns_clone def naive(self, naive=True): @@ -3578,6 +3633,7 @@ def __init__(self, database, 
threadlocals=True, autocommit=True, self.field_overrides = merge_dict(self.field_overrides, fields or {}) self.op_overrides = merge_dict(self.op_overrides, ops or {}) + self.exception_wrapper = ExceptionWrapper(self.exceptions) def init(self, database, **connect_kwargs): if not self.is_closed(): @@ -3586,19 +3642,15 @@ def init(self, database, **connect_kwargs): self.database = database self.connect_kwargs.update(connect_kwargs) - def exception_wrapper(self): - return ExceptionWrapper(self.exceptions) - def connect(self): with self._conn_lock: if self.deferred: - raise Exception('Error, database not properly initialized ' - 'before opening connection') - with self.exception_wrapper(): - self._local.conn = self._connect( - self.database, - **self.connect_kwargs) - self._local.closed = False + raise OperationalError('Database has not been initialized') + if not self._local.closed: + raise OperationalError('Connection already open') + self._local.conn = self._create_connection() + self._local.closed = False + with self.exception_wrapper: self.initialize_connection(self._local.conn) def initialize_connection(self, conn): @@ -3609,7 +3661,7 @@ def close(self): if self.deferred: raise Exception('Error, database not properly initialized ' 'before closing connection') - with self.exception_wrapper(): + with self.exception_wrapper: self._close(self._local.conn) self._local.closed = True @@ -3622,6 +3674,10 @@ def get_conn(self): self.connect() return self._local.conn + def _create_connection(self): + with self.exception_wrapper: + return self._connect(self.database, **self.connect_kwargs) + def is_closed(self): return self._local.closed @@ -3677,12 +3733,12 @@ def execute(self, clause): def execute_sql(self, sql, params=None, require_commit=True): logger.debug((sql, params)) - with self.exception_wrapper(): + with self.exception_wrapper: cursor = self.get_cursor() try: cursor.execute(sql, params or ()) except Exception: - if self.get_autocommit() and self.autorollback: + if self.autorollback and self.get_autocommit(): self.rollback() raise else: @@ -3694,10 +3750,12 @@ def begin(self): pass def commit(self): - self.get_conn().commit() + with self.exception_wrapper: + self.get_conn().commit() def rollback(self): - self.get_conn().rollback() + with self.exception_wrapper: + self.get_conn().rollback() def set_autocommit(self, autocommit): self._local.autocommit = autocommit @@ -3716,8 +3774,10 @@ def pop_execution_context(self): def execution_context_depth(self): return len(self._local.context_stack) - def execution_context(self, with_transaction=True): - return ExecutionContext(self, with_transaction=with_transaction) + def execution_context(self, with_transaction=True, transaction_type=None): + return ExecutionContext(self, with_transaction, transaction_type) + + __call__ = execution_context def push_transaction(self, transaction): self._local.transactions.append(transaction) @@ -3728,23 +3788,17 @@ def pop_transaction(self): def transaction_depth(self): return len(self._local.transactions) - def transaction(self): - return transaction(self) - - def commit_on_success(self, func): - @wraps(func) - def inner(*args, **kwargs): - with self.transaction(): - return func(*args, **kwargs) - return inner + def transaction(self, transaction_type=None): + return transaction(self, transaction_type) + commit_on_success = property(transaction) def savepoint(self, sid=None): if not self.savepoints: raise NotImplementedError return savepoint(self, sid) - def atomic(self): - return _atomic(self) + def atomic(self, 
transaction_type=None): + return _atomic(self, transaction_type) def get_tables(self, schema=None): raise NotImplementedError @@ -3839,6 +3893,13 @@ def get_noop_sql(self): def get_binary_type(self): return binary_construct +def __pragma__(name): + def __get__(self): + return self.pragma(name) + def __set__(self, value): + return self.pragma(name, value) + return property(__get__, __set__) + class SqliteDatabase(Database): compiler_class = SqliteQueryCompiler field_overrides = { @@ -3888,8 +3949,28 @@ def _set_pragmas(self, conn): cursor.execute('PRAGMA %s = %s;' % (pragma, value)) cursor.close() - def begin(self, lock_type='DEFERRED'): - self.execute_sql('BEGIN %s' % lock_type, require_commit=False) + def pragma(self, key, value=SENTINEL): + sql = 'PRAGMA %s' % key + if value is not SENTINEL: + sql += ' = %s' % value + return self.execute_sql(sql).fetchone() + + cache_size = __pragma__('cache_size') + foreign_keys = __pragma__('foreign_keys') + journal_mode = __pragma__('journal_mode') + journal_size_limit = __pragma__('journal_size_limit') + mmap_size = __pragma__('mmap_size') + page_size = __pragma__('page_size') + read_uncommitted = __pragma__('read_uncommitted') + synchronous = __pragma__('synchronous') + wal_autocheckpoint = __pragma__('wal_autocheckpoint') + + def begin(self, lock_type=None): + statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' + self.execute_sql(statement, require_commit=False) + + def transaction(self, transaction_type=None): + return transaction_sqlite(self, transaction_type) def create_foreign_key(self, model_class, field, constraint=None): raise OperationalError('SQLite does not support ALTER TABLE ' @@ -4204,6 +4285,7 @@ def get_binary_type(self): class _callable_context_manager(object): + __slots__ = () def __call__(self, fn): @wraps(fn) def inner(*args, **kwargs): @@ -4212,9 +4294,10 @@ def inner(*args, **kwargs): return inner class ExecutionContext(_callable_context_manager): - def __init__(self, database, with_transaction=True): + def __init__(self, database, with_transaction=True, transaction_type=None): self.database = database self.with_transaction = with_transaction + self.transaction_type = transaction_type self.connection = None def __enter__(self): @@ -4260,41 +4343,42 @@ def __exit__(self, exc_type, exc_val, exc_tb): model._meta.database = self._orig[i] class _atomic(_callable_context_manager): - def __init__(self, db): + __slots__ = ('db', 'transaction_type', 'context_manager') + def __init__(self, db, transaction_type=None): self.db = db + self.transaction_type = transaction_type def __enter__(self): if self.db.transaction_depth() == 0: - self._helper = self.db.transaction() + self.context_manager = self.db.transaction(self.transaction_type) else: - self._helper = self.db.savepoint() - return self._helper.__enter__() + self.context_manager = self.db.savepoint() + return self.context_manager.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): - return self._helper.__exit__(exc_type, exc_val, exc_tb) + return self.context_manager.__exit__(exc_type, exc_val, exc_tb) class transaction(_callable_context_manager): - def __init__(self, db): + __slots__ = ('db', 'autocommit', 'transaction_type') + def __init__(self, db, transaction_type=None): self.db = db + self.transaction_type = transaction_type def _begin(self): self.db.begin() def commit(self, begin=True): self.db.commit() - if begin: - self._begin() + if begin: self._begin() def rollback(self, begin=True): self.db.rollback() - if begin: - self._begin() + if begin: self._begin() def 
__enter__(self): - self._orig = self.db.get_autocommit() + self.autocommit = self.db.get_autocommit() self.db.set_autocommit(False) - if self.db.transaction_depth() == 0: - self._begin() + if self.db.transaction_depth() == 0: self._begin() self.db.push_transaction(self) return self @@ -4309,10 +4393,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.rollback(False) raise finally: - self.db.set_autocommit(self._orig) + self.db.set_autocommit(self.autocommit) self.db.pop_transaction() class savepoint(_callable_context_manager): + __slots__ = ('db', 'sid', 'quoted_sid', 'autocommit') def __init__(self, db, sid=None): self.db = db _compiler = db.compiler() @@ -4329,7 +4414,7 @@ def rollback(self): self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) def __enter__(self): - self._orig_autocommit = self.db.get_autocommit() + self.autocommit = self.db.get_autocommit() self.db.set_autocommit(False) self._execute('SAVEPOINT %s;' % self.quoted_sid) return self @@ -4345,19 +4430,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.rollback() raise finally: - self.db.set_autocommit(self._orig_autocommit) + self.db.set_autocommit(self.autocommit) + +class transaction_sqlite(transaction): + __slots__ = () + def _begin(self): + self.db.begin(lock_type=self.transaction_type) class savepoint_sqlite(savepoint): + __slots__ = ('isolation_level',) def __enter__(self): conn = self.db.get_conn() # For sqlite, the connection's isolation_level *must* be set to None. # The act of setting it, though, will break any existing savepoints, # so only write to it if necessary. if conn.isolation_level is not None: - self._orig_isolation_level = conn.isolation_level + self.isolation_level = conn.isolation_level conn.isolation_level = None else: - self._orig_isolation_level = None + self.isolation_level = None return super(savepoint_sqlite, self).__enter__() def __exit__(self, exc_type, exc_val, exc_tb): @@ -4365,8 +4456,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): return super(savepoint_sqlite, self).__exit__( exc_type, exc_val, exc_tb) finally: - if self._orig_isolation_level is not None: - self.db.get_conn().isolation_level = self._orig_isolation_level + if self.isolation_level is not None: + self.db.get_conn().isolation_level = self.isolation_level class FieldProxy(Field): def __init__(self, alias, field_instance): @@ -4465,7 +4556,8 @@ class ModelOptions(object): def __init__(self, cls, database=None, db_table=None, db_table_func=None, indexes=None, order_by=None, primary_key=None, table_alias=None, constraints=None, schema=None, - validate_backrefs=True, only_save_dirty=False, **kwargs): + validate_backrefs=True, only_save_dirty=False, + depends_on=None, **kwargs): self.model_class = cls self.name = cls.__name__.lower() self.fields = {} @@ -4492,6 +4584,7 @@ def __init__(self, cls, database=None, db_table=None, db_table_func=None, self.schema = schema self.validate_backrefs = validate_backrefs self.only_save_dirty = only_save_dirty + self.depends_on = depends_on self.auto_increment = None self.composite_key = False @@ -4529,7 +4622,7 @@ def _update_field_lists(self): set(self.fields.values()) | set((self.primary_key,))) self.declared_fields = [field for field in self.sorted_fields - if not isinstance(field, _AutoPrimaryKeyField)] + if not field.undeclared] def add_field(self, field): self.remove_field(field.name) diff --git a/playhouse/_speedups.pyx b/playhouse/_speedups.pyx index cb8ee0234..1a822e2e7 100644 --- a/playhouse/_speedups.pyx +++ b/playhouse/_speedups.pyx @@ -332,6 +332,9 @@ cdef 
     for foreign_key in model._meta.reverse_rel.values():
         _sort_models(foreign_key.model_class, model_set, seen, accum)
     accum.append(model)
+    if model._meta.depends_on is not None:
+        for dependency in model._meta.depends_on:
+            _sort_models(dependency, model_set, seen, accum)

 def sort_models_topologically(models):
     cdef:
diff --git a/playhouse/_sqlite_ext.pyx b/playhouse/_sqlite_ext.pyx
index 6edfeb139..7426583d5 100644
--- a/playhouse/_sqlite_ext.pyx
+++ b/playhouse/_sqlite_ext.pyx
@@ -61,8 +61,7 @@ cpdef peewee_regexp(regex_str, value):
     if value is None or regex_str is None:
         return

-    regex = re.compile(regex_str, re.I)
-    if value and regex.search(value):
+    if value and re.search(regex_str, value, re.I):
         return True
     return False
diff --git a/playhouse/apsw_ext.py b/playhouse/apsw_ext.py
index a7f6f545b..f7b0b9830 100644
--- a/playhouse/apsw_ext.py
+++ b/playhouse/apsw_ext.py
@@ -111,7 +111,7 @@ def _execute_sql(self, cursor, sql, params):

     def execute_sql(self, sql, params=None, require_commit=True):
         logger.debug((sql, params))
-        with self.exception_wrapper():
+        with self.exception_wrapper:
             cursor = self.get_cursor()
             self._execute_sql(cursor, sql, params)
         return cursor
diff --git a/playhouse/migrate.py b/playhouse/migrate.py
index dc3b3893b..2c0237349 100644
--- a/playhouse/migrate.py
+++ b/playhouse/migrate.py
@@ -244,6 +244,10 @@ def add_column(self, table, column_name, field):
                     field.rel_model._meta.db_table,
                     field.to_field.db_column))

+        if field.index or field.unique:
+            operations.append(
+                self.add_index(table, (column_name,), field.unique))
+
         return operations

     @operation
@@ -442,7 +446,9 @@ def get_foreign_key_constraint(self, table, column_name):
             'FROM information_schema.key_column_usage WHERE '
             'table_schema = DATABASE() AND '
             'table_name = %s AND '
-            'column_name = %s;'),
+            'column_name = %s AND '
+            'referenced_table_name IS NOT NULL AND '
+            'referenced_column_name IS NOT NULL;'),
             (table, column_name))
         result = cursor.fetchone()
         if not result:
diff --git a/playhouse/pool.py b/playhouse/pool.py
index 228017de2..6345623b5 100644
--- a/playhouse/pool.py
+++ b/playhouse/pool.py
@@ -60,7 +60,17 @@ def do_something(foo, bar):
     """
 import heapq
 import logging
+import threading
 import time
+try:
+    from Queue import Queue
+except ImportError:
+    from queue import Queue
+
+try:
+    from psycopg2 import extensions as pg_extensions
+except ImportError:
+    pg_extensions = None

 from peewee import MySQLDatabase
 from peewee import PostgresqlDatabase
@@ -75,25 +85,54 @@ def make_int(val):
     return val


+class MaxConnectionsExceeded(ValueError): pass
+
+
 class PooledDatabase(object):
     def __init__(self, database, max_connections=20, stale_timeout=None,
-                 **kwargs):
+                 timeout=None, **kwargs):
         self.max_connections = make_int(max_connections)
         self.stale_timeout = make_int(stale_timeout)
+        self.timeout = make_int(timeout)
+        if self.timeout == 0:
+            self.timeout = float('inf')
+        self._closed = set()
         self._connections = []
         self._in_use = {}
-        self._closed = set()
         self.conn_key = id
+        if self.timeout:
+            self._event = threading.Event()
+            self._ready_queue = Queue()

         super(PooledDatabase, self).__init__(database, **kwargs)

     def init(self, database, max_connections=None, stale_timeout=None,
-             **connect_kwargs):
+             timeout=None, **connect_kwargs):
         super(PooledDatabase, self).init(database, **connect_kwargs)
         if max_connections is not None:
             self.max_connections = make_int(max_connections)
         if stale_timeout is not None:
             self.stale_timeout = make_int(stale_timeout)
+        if timeout is not None:
+            self.timeout = make_int(timeout)
+            if self.timeout == 0:
+                self.timeout = float('inf')
+
+    def connect(self):
+        if self.timeout:
+            start = time.time()
+            while start + self.timeout > time.time():
+                try:
+                    super(PooledDatabase, self).connect()
+                except MaxConnectionsExceeded:
+                    time.sleep(0.1)
+                else:
+                    return
+            raise MaxConnectionsExceeded('Max connections exceeded, timed out '
+                                         'attempting to connect.')
+        else:
+            super(PooledDatabase, self).connect()

     def _connect(self, *args, **kwargs):
         while True:
@@ -130,7 +169,7 @@ def _connect(self, *args, **kwargs):
             if conn is None:
                 if self.max_connections and (
                         len(self._in_use) >= self.max_connections):
-                    raise ValueError('Exceeded maximum connections.')
+                    raise MaxConnectionsExceeded('Exceeded maximum connections.')
                 conn = super(PooledDatabase, self)._connect(*args, **kwargs)
                 ts = time.time()
                 key = self.conn_key(conn)
@@ -140,11 +179,17 @@ def _connect(self, *args, **kwargs):
         return conn

     def _is_stale(self, timestamp):
+        # Called on check-out and check-in to ensure the connection has
+        # not outlived the stale timeout.
         return (time.time() - timestamp) > self.stale_timeout

     def _is_closed(self, key, conn):
         return key in self._closed

+    def _can_reuse(self, conn):
+        # Called on check-in to make sure the connection can be re-used.
+        return True
+
     def _close(self, conn, close_conn=False):
         key = self.conn_key(conn)
         if close_conn:
@@ -156,9 +201,11 @@ def _close(self, conn, close_conn=False):
             if self.stale_timeout and self._is_stale(ts):
                 logger.debug('Closing stale connection %s.', key)
                 super(PooledDatabase, self)._close(conn)
-            else:
+            elif self._can_reuse(conn):
                 logger.debug('Returning %s to pool.', key)
                 heapq.heappush(self._connections, (ts, conn))
+            else:
+                logger.debug('Closed %s.', key)

     def manual_close(self):
         """
@@ -195,6 +242,14 @@ def _is_closed(self, key, conn):
             closed = bool(conn.closed)
         return closed

+    def _can_reuse(self, conn):
+        txn_status = conn.get_transaction_status()
+        # Do not return connection in an error state, as subsequent queries
+        # will all fail.
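+        # A sketch of the failure this guards against (the database name
+        # and pool size below are illustrative, not part of this module):
+        #
+        #     db = PooledPostgresqlDatabase('app_db', max_connections=8)
+        #     try:
+        #         db.execute_sql('select 1/0')  # Leaves the txn INERROR.
+        #     except Exception:
+        #         pass
+        #     db.close()  # Check-in: _can_reuse() resets the connection.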
+        if txn_status == pg_extensions.TRANSACTION_STATUS_INERROR:
+            conn.reset()
+        return True
+

 class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase):
     pass
diff --git a/playhouse/postgres_ext.py b/playhouse/postgres_ext.py
index 2482d9ebd..9c015dfeb 100644
--- a/playhouse/postgres_ext.py
+++ b/playhouse/postgres_ext.py
@@ -164,7 +164,7 @@ def contains_any(self, *items):

 class DateTimeTZField(DateTimeField):
-    db_field = 'datetime_tz'
+    db_field = 'timestamptz'

 class HStoreField(IndexedFieldMixin, Field):
@@ -261,12 +261,17 @@ class TSVectorField(IndexedFieldMixin, TextField):
     db_field = 'tsvector'
     default_index_type = 'GIN'

-    def match(self, query):
-        return Expression(self, OP.TS_MATCH, fn.to_tsquery(query))
+    def match(self, query, language=None):
+        params = (language, query) if language is not None else (query,)
+        return Expression(self, OP.TS_MATCH, fn.to_tsquery(*params))

-def Match(field, query):
-    return Expression(fn.to_tsvector(field), OP.TS_MATCH, fn.to_tsquery(query))
+def Match(field, query, language=None):
+    params = (language, query) if language is not None else (query,)
+    return Expression(
+        fn.to_tsvector(field),
+        OP.TS_MATCH,
+        fn.to_tsquery(*params))

 OP.update(
@@ -364,7 +369,7 @@ def execute_sql(self, sql, params=None, require_commit=True,
         use_named_cursor = (named_cursor or (
             self.server_side_cursors and sql.lower().startswith('select')))
-        with self.exception_wrapper():
+        with self.exception_wrapper:
             if use_named_cursor:
                 cursor = self.get_cursor(name=str(uuid.uuid1()))
                 require_commit = False
diff --git a/playhouse/pskel b/playhouse/pskel
index 2be68e202..8da40bd08 100755
--- a/playhouse/pskel
+++ b/playhouse/pskel
@@ -88,8 +88,8 @@ if __name__ == '__main__':
     ao('-d', '--database', dest='database', default=':memory:')

     options, models = parser.parse_args()
-    print render_models(
+    print(render_models(
         models,
         engine=engine_mapping[options.engine],
         database=options.database,
-        logging=options.logging)
+        logging=options.logging))
diff --git a/playhouse/shortcuts.py b/playhouse/shortcuts.py
index b84ded25c..c7c4b6ea2 100644
--- a/playhouse/shortcuts.py
+++ b/playhouse/shortcuts.py
@@ -122,7 +122,7 @@ def model_to_dict(model, recurse=True, backrefs=False, only=None,
                     seen=seen,
                     max_depth=max_depth - 1)
             else:
-                field_data = {}
+                field_data = None

         data[field.name] = field_data

@@ -211,7 +211,7 @@ def execute_sql(self, sql, params=None, require_commit=True):
         except OperationalError:
             if not self.is_closed():
                 self.close()
-        with self.exception_wrapper():
+        with self.exception_wrapper:
             cursor = self.get_cursor()
             cursor.execute(sql, params or ())
             if require_commit and self.get_autocommit():
diff --git a/playhouse/sqlite_ext.py b/playhouse/sqlite_ext.py
index 5018697f3..0ad0735a6 100644
--- a/playhouse/sqlite_ext.py
+++ b/playhouse/sqlite_ext.py
@@ -43,11 +43,6 @@ class Document(FTSModel):
 except ImportError:
     import json

-try:
-    from vtfunc import TableFunction
-except ImportError:
-    pass
-
 from peewee import *
 from peewee import EnclosedClause
 from peewee import Entity
@@ -700,7 +695,8 @@ def descendants(cls, node, depth=None, include_node=False):
         query = (model_class
                  .select(model_class, cls.depth.alias('depth'))
                  .join(cls, on=(primary_key == cls.id))
-                 .where(cls.root == node))
+                 .where(cls.root == node)
+                 .naive())
         if depth is not None:
             query = query.where(cls.depth == depth)
         elif not include_node:
@@ -712,7 +708,8 @@ def ancestors(cls, node, depth=None, include_node=False):
         query = (model_class
                  .select(model_class, cls.depth.alias('depth'))
                  .join(cls, on=(primary_key == cls.root))
-                 .where(cls.id == node))
+                 .where(cls.id == node)
+                 .naive())
         if depth:
             query = query.where(cls.depth == depth)
         elif not include_node:
@@ -953,20 +950,6 @@ def create_index(self, model_class, field_name, unique=False):
         return super(SqliteExtDatabase, self).create_index(
             model_class, field_name, unique)

-    def granular_transaction(self, lock_type='deferred'):
-        assert lock_type.lower() in ('deferred', 'immediate', 'exclusive')
-        return granular_transaction(self, lock_type)
-
-
-class granular_transaction(transaction):
-    def __init__(self, db, lock_type='deferred'):
-        self.db = db
-        self.conn = self.db.get_conn()
-        self.lock_type = lock_type
-
-    def _begin(self):
-        self.db.begin(self.lock_type)
-
 OP.MATCH = 'match'

 SqliteExtDatabase.register_ops({
diff --git a/playhouse/sqliteq.py b/playhouse/sqliteq.py
index ffeb87db7..edfa434e1 100644
--- a/playhouse/sqliteq.py
+++ b/playhouse/sqliteq.py
@@ -1,5 +1,6 @@
 import logging
 import weakref
+from threading import local as thread_local
 from threading import Event
 from threading import Thread
 try:
@@ -11,6 +12,7 @@
     import gevent
     from gevent import Greenlet as GThread
     from gevent.event import Event as GEvent
+    from gevent.local import local as greenlet_local
     from gevent.queue import Queue as GQueue
 except ImportError:
     GThread = GQueue = GEvent = None
@@ -24,10 +26,16 @@
 class ResultTimeout(Exception):
     pass

+class WriterPaused(Exception):
+    pass
+
+class ShutdownException(Exception):
+    pass
+
 class AsyncCursor(object):
     __slots__ = ('sql', 'params', 'commit', 'timeout',
-                 '_event', '_cursor', '_exc', '_idx', '_rows')
+                 '_event', '_cursor', '_exc', '_idx', '_rows', '_ready')

     def __init__(self, event, sql, params, commit, timeout):
         self._event = event
@@ -36,6 +44,7 @@ def __init__(self, event, sql, params, commit, timeout):
         self.commit = commit
         self.timeout = timeout
         self._cursor = self._exc = self._idx = self._rows = None
+        self._ready = False

     def set_result(self, cursor, exc=None):
         self._cursor = cursor
@@ -51,14 +60,18 @@ def _wait(self, timeout=None):
             raise ResultTimeout('results not ready, timed out.')
         if self._exc is not None:
             raise self._exc
+        self._ready = True

     def __iter__(self):
-        self._wait()
+        if not self._ready:
+            self._wait()
         if self._exc is not None:
             raise self._exc
         return self

     def next(self):
+        if not self._ready:
+            self._wait()
         try:
             obj = self._rows[self._idx]
         except IndexError:
@@ -70,12 +83,14 @@ def next(self):

     @property
     def lastrowid(self):
-        self._wait()
+        if not self._ready:
+            self._wait()
         return self._cursor.lastrowid

     @property
     def rowcount(self):
-        self._wait()
+        if not self._ready:
+            self._wait()
         return self._cursor.rowcount

     @property
@@ -89,31 +104,99 @@ def fetchall(self):
         return list(self)  # Iterating implies waiting until populated.

     def fetchone(self):
-        self._wait()
+        if not self._ready:
+            self._wait()
         try:
             return next(self)
         except StopIteration:
             return None

+SHUTDOWN = StopIteration
+PAUSE = object()
+UNPAUSE = object()
+
+
+class Writer(object):
+    __slots__ = ('database', 'queue')
+
+    def __init__(self, database, queue):
+        self.database = database
+        self.queue = queue
+
+    def run(self):
+        conn = self.database.get_conn()
+        try:
+            while True:
+                try:
+                    if conn is None:  # Paused.
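+                        # While paused the writer holds no connection: it
+                        # blocks on the queue until UNPAUSE (reconnect) or
+                        # SHUTDOWN arrives, and wait_unpause() fails any
+                        # AsyncCursor received in the meantime with
+                        # WriterPaused. Sketch of the client-side protocol:
+                        #
+                        #     db.pause()    # Writer closes its connection.
+                        #     ...           # Writes now raise WriterPaused.
+                        #     db.unpause()  # Writer reconnects and resumes.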
+                        if self.wait_unpause():
+                            conn = self.database.get_conn()
+                    else:
+                        conn = self.loop(conn)
+                except ShutdownException:
+                    logger.info('writer received shutdown request, exiting.')
+                    return
+        finally:
+            if conn is not None:
+                self.database._close(conn)
+                self.database._local.closed = True
+
+    def wait_unpause(self):
+        obj = self.queue.get()
+        if obj is UNPAUSE:
+            logger.info('writer unpaused - reconnecting to database.')
+            return True
+        elif obj is SHUTDOWN:
+            raise ShutdownException()
+        elif obj is PAUSE:
+            logger.error('writer received pause, but is already paused.')
+        else:
+            obj.set_result(None, WriterPaused())
+            logger.warning('writer paused, not handling %s', obj)
+
+    def loop(self, conn):
+        obj = self.queue.get()
+        if isinstance(obj, AsyncCursor):
+            self.execute(obj)
+        elif obj is PAUSE:
+            logger.info('writer paused - closing database connection.')
+            self.database._close(conn)
+            self.database._local.closed = True
+            return
+        elif obj is UNPAUSE:
+            logger.error('writer received unpause, but is already running.')
+        elif obj is SHUTDOWN:
+            raise ShutdownException()
+        else:
+            logger.error('writer received unsupported object: %s', obj)
+        return conn

-THREADLOCAL_ERROR_MESSAGE = ('threadlocals cannot be set to True when using '
-                             'the Sqlite thread / queue database. All queries '
-                             'are serialized through a single connection, so '
-                             'allowing multiple threads to connect defeats '
-                             'the purpose of this database.')
-WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL journal '
-                          'mode when using this feature. WAL mode allows '
-                          'one or more readers to continue reading while '
-                          'another connection writes to the database.')
+    def execute(self, obj):
+        logger.debug('received query %s', obj.sql)
+        try:
+            cursor = self.database._execute(obj.sql, obj.params, obj.commit)
+        except Exception as execute_err:
+            cursor = None
+            # Bind the exception to a new name: Python 3 deletes the
+            # "except ... as" target when the block exits.
+            exc = execute_err
+        else:
+            exc = None
+        return obj.set_result(cursor, exc)

 class SqliteQueueDatabase(SqliteExtDatabase):
-    def __init__(self, database, use_gevent=False, autostart=False, readers=1,
+    WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
+                              'journal mode when using this feature. WAL mode '
+                              'allows one or more readers to continue reading '
+                              'while another connection writes to the '
+                              'database.')
+
+    def __init__(self, database, use_gevent=False, autostart=True,
                  queue_max_size=None, results_timeout=None, *args, **kwargs):
-        if kwargs.get('threadlocals'):
-            raise ValueError(THREADLOCAL_ERROR_MESSAGE)
+        if 'threadlocals' in kwargs and not kwargs['threadlocals']:
+            raise ValueError('"threadlocals" must be true to use the '
+                             'SqliteQueueDatabase.')

-        kwargs['threadlocals'] = False
+        kwargs['threadlocals'] = True
         kwargs['check_same_thread'] = False

         # Ensure that journal_mode is WAL. This value is passed to the parent
@@ -126,18 +209,20 @@ def __init__(self, database, use_gevent=False, autostart=False, readers=1,
         # execute_sql(), this is just a handy way to reference the real
         # implementation.
         Parent = super(SqliteQueueDatabase, self)
-        self.__execute_sql = Parent.execute_sql
+        self._execute = Parent.execute_sql

         # Call the parent class constructor with our modified pragmas.
         Parent.__init__(database, pragmas=pragmas, *args, **kwargs)

         self._autostart = autostart
         self._results_timeout = results_timeout
-        self._num_readers = readers
         self._is_stopped = True
+
+        # Get different objects depending on the threading implementation.
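+        # A single flag swaps every concurrency primitive (thread, queue,
+        # event) between the stdlib and gevent implementations, e.g.:
+        #
+        #     db = SqliteQueueDatabase('app.db')                   # OS thread
+        #     db = SqliteQueueDatabase('app.db', use_gevent=True)  # greenlet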
         self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)
-        self._create_queues_and_workers()
+
+        # Create the writer thread, optionally starting it.
+        self._create_write_queue()

         if self._autostart:
             self.start()
@@ -146,69 +231,47 @@ def get_thread_impl(self, use_gevent):

     def _validate_journal_mode(self, journal_mode=None, pragmas=None):
         if journal_mode and journal_mode.lower() != 'wal':
-            raise ValueError(WAL_MODE_ERROR_MESSAGE)
+            raise ValueError(self.WAL_MODE_ERROR_MESSAGE)

         if pragmas:
             pdict = dict((k.lower(), v) for (k, v) in pragmas)
             if pdict.get('journal_mode', 'wal').lower() != 'wal':
-                raise ValueError(WAL_MODE_ERROR_MESSAGE)
+                raise ValueError(self.WAL_MODE_ERROR_MESSAGE)
             return [(k, v) for (k, v) in pragmas
                     if k != 'journal_mode'] + [('journal_mode', 'wal')]
         else:
             return [('journal_mode', 'wal')]

-    def _create_queues_and_workers(self):
+    def _create_write_queue(self):
         self._write_queue = self._thread_helper.queue()
-        self._read_queue = self._thread_helper.queue()
-
-        target = self._run_worker_loop
-        self._writer = self._thread_helper.thread(target, self._write_queue)
-        self._readers = [self._thread_helper.thread(target, self._read_queue)
-                         for _ in range(self._num_readers)]
-
-    def _run_worker_loop(self, queue):
-        while True:
-            async_cursor = queue.get()
-            if async_cursor is StopIteration:
-                logger.info('worker shutting down.')
-                return
-
-            logger.debug('received query %s', async_cursor.sql)
-            self._process_execution(async_cursor)
-
-    def _process_execution(self, async_cursor):
-        try:
-            cursor = self.__execute_sql(async_cursor.sql, async_cursor.params,
-                                        async_cursor.commit)
-        except Exception as exc:
-            cursor = None
-        else:
-            exc = None
-        return async_cursor.set_result(cursor, exc)

     def queue_size(self):
-        return (self._write_queue.qsize(), self._read_queue.qsize())
+        return self._write_queue.qsize()

     def execute_sql(self, sql, params=None, require_commit=True, timeout=None):
+        if not require_commit:
+            return self._execute(sql, params, require_commit=require_commit)
+
         cursor = AsyncCursor(
             event=self._thread_helper.event(),
             sql=sql,
             params=params,
             commit=require_commit,
             timeout=self._results_timeout if timeout is None else timeout)
-        queue = self._write_queue if require_commit else self._read_queue
-        queue.put(cursor)
+        self._write_queue.put(cursor)
         return cursor

     def start(self):
         with self._conn_lock:
             if not self._is_stopped:
                 return False
+            def run():
+                writer = Writer(self, self._write_queue)
+                writer.run()
+
+            self._writer = self._thread_helper.thread(run)
             self._writer.start()
-            for reader in self._readers:
-                reader.start()
-            logger.info('workers started.')
             self._is_stopped = False
             return True

@@ -217,18 +280,27 @@ def stop(self):
         with self._conn_lock:
             if self._is_stopped:
                 return False
-            self._write_queue.put(StopIteration)
-            for _ in self._readers:
-                self._read_queue.put(StopIteration)
+            self._write_queue.put(SHUTDOWN)
             self._writer.join()
-            for reader in self._readers:
-                reader.join()
+            self._is_stopped = True
             return True

     def is_stopped(self):
         with self._conn_lock:
             return self._is_stopped

+    def pause(self):
+        with self._conn_lock:
+            self._write_queue.put(PAUSE)
+
+    def unpause(self):
+        with self._conn_lock:
+            self._write_queue.put(UNPAUSE)
+
+    def __unsupported__(self, *args, **kwargs):
+        raise ValueError('This method is not supported by %r.'
+                         % type(self))
+    atomic = transaction = savepoint = __unsupported__
+

 class ThreadHelper(object):
     __slots__ = ('queue_max_size',)
@@ -249,7 +321,7 @@ def thread(self, fn, *args, **kwargs):

 class GreenletHelper(ThreadHelper):
-    __slots__ = ('queue_max_size',)
+    __slots__ = ()

     def event(self):
         return GEvent()
diff --git a/playhouse/tests/test_compound_queries.py b/playhouse/tests/test_compound_queries.py
index b88ed8825..7f3fcd0de 100644
--- a/playhouse/tests/test_compound_queries.py
+++ b/playhouse/tests/test_compound_queries.py
@@ -163,8 +163,8 @@ def test_multiple_with_parentheses(self):
         compound = lhs | queries[2]
         sql, params = compound.sql()
         self.assertEqual(sql, (
-            '((SELECT "t1"."alpha" FROM "alpha" AS t1) UNION '
-            '(SELECT "t2"."alpha" FROM "alpha" AS t2)) UNION '
+            '(SELECT "t1"."alpha" FROM "alpha" AS t1) UNION '
+            '(SELECT "t2"."alpha" FROM "alpha" AS t2) UNION '
             '(SELECT "t3"."alpha" FROM "alpha" AS t3)'))

         lhs = queries[0]
@@ -172,8 +172,8 @@ def test_multiple_with_parentheses(self):
         sql, params = compound.sql()
         self.assertEqual(sql, (
             '(SELECT "t3"."alpha" FROM "alpha" AS t3) UNION '
-            '((SELECT "t1"."alpha" FROM "alpha" AS t1) UNION '
-            '(SELECT "t2"."alpha" FROM "alpha" AS t2))'))
+            '(SELECT "t1"."alpha" FROM "alpha" AS t1) UNION '
+            '(SELECT "t2"."alpha" FROM "alpha" AS t2)'))

     def test_inner_limit(self):
         compound_db.compound_select_parentheses = True
diff --git a/playhouse/tests/test_helpers.py b/playhouse/tests/test_helpers.py
index ce1573f9b..97f0eb247 100644
--- a/playhouse/tests/test_helpers.py
+++ b/playhouse/tests/test_helpers.py
@@ -66,6 +66,52 @@ def assert_precedes(X, Y):
         assert_precedes(A, E)

+class TestDeclaredDependencies(PeeweeTestCase):
+    def test_declared_dependencies(self):
+        class A(Model): pass
+        class B(Model):
+            a = ForeignKeyField(A)
+            b = ForeignKeyField('self')
+        class NA(Model):
+            class Meta:
+                depends_on = (A, B)
+        class C(Model):
+            b = ForeignKeyField(B)
+            c = ForeignKeyField('self')
+            class Meta:
+                depends_on = (NA,)
+        class D1(Model):
+            na = ForeignKeyField(NA)
+            class Meta:
+                depends_on = (A, C)
+        class D2(Model):
+            class Meta:
+                depends_on = (NA, D1, C, B)
+
+        models = [A, B, C, D1, D2]
+        ordered = list(models)
+        for pmodels in permutations(models):
+            ordering = sort_models_topologically(pmodels)
+            self.assertEqual(ordering, ordered)
+
+    def test_declared_dependencies_simple(self):
+        class A(Model): pass
+        class B(Model):
+            class Meta:
+                depends_on = (A,)
+        class C(Model):
+            b = ForeignKeyField(B)  # Implicit dependency.
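+        # C reaches B through an implicit FK dependency, while D (below)
+        # reaches C through an explicit Meta.depends_on; the sort has to
+        # interleave both kinds to produce A, B, C, D.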
+        class D(Model):
+            class Meta:
+                depends_on = (C,)
+
+        models = [A, B, C, D]
+        ordered = list(models)
+        for pmodels in permutations(models):
+            ordering = sort_models_topologically(pmodels)
+            self.assertEqual(ordering, ordered)
+
+
 def permutations(xs):
     if not xs:
         yield []
diff --git a/playhouse/tests/test_migrate.py b/playhouse/tests/test_migrate.py
index de45cfbb4..018d4186c 100644
--- a/playhouse/tests/test_migrate.py
+++ b/playhouse/tests/test_migrate.py
@@ -1,5 +1,6 @@
 import datetime
 import os
+from functools import partial

 from peewee import *
 from peewee import print_
@@ -59,6 +60,10 @@ class Page(Model):
     name = CharField(max_length=100, unique=True, null=True)
     user = ForeignKeyField(User, null=True, related_name='pages')

+class Session(Model):
+    user = ForeignKeyField(User, unique=True, related_name='sessions')
+    updated_at = DateField(null=True)
+
 class IndexModel(Model):
     first_name = CharField()
     last_name = CharField()
@@ -75,6 +80,7 @@ class Meta:
     Tag,
     User,
     Page,
+    Session
 ]

 class BaseMigrationTestCase(object):
@@ -123,8 +129,7 @@ def test_add_column(self):
         t2 = Tag.create(tag='t2')

         # Convenience function for generating `add_column` migrations.
-        def add_column(field_name, field_obj):
-            return self.migrator.add_column('tag', field_name, field_obj)
+        add_column = partial(self.migrator.add_column, 'tag')

         # Run the migration.
         migrate(
@@ -360,6 +365,27 @@ def test_add_index(self):
             first_name='first',
             last_name='last')

+    def test_add_unique_column(self):
+        uf = CharField(default='', unique=True)
+
+        # Run the migration.
+        migrate(self.migrator.add_column('tag', 'unique_field', uf))
+
+        # Create a new tag model to represent the fields we added.
+        class NewTag(Model):
+            tag = CharField()
+            unique_field = uf
+
+            class Meta:
+                database = self.database
+                db_table = Tag._meta.db_table
+
+        NewTag.create(tag='t1', unique_field='u1')
+        NewTag.create(tag='t2', unique_field='u2')
+        with self.database.atomic():
+            self.assertRaises(IntegrityError, NewTag.create, tag='t3',
+                              unique_field='u1')
+
     def test_drop_index(self):
         # Create a unique index.
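+        # Reuses test_add_index() purely as setup; the unique index created
+        # there is the one dropped below.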
         self.test_add_index()
@@ -493,6 +519,19 @@ def test_rename_foreign_key(self):
         self.assertEqual(foreign_key.dest_column, 'id')
         self.assertEqual(foreign_key.dest_table, 'users')

+    def test_rename_unique_foreign_key(self):
+        migrate(self.migrator.rename_column('session', 'user_id', 'huey_id'))
+        columns = self.database.get_columns('session')
+        self.assertEqual(
+            sorted(column.name for column in columns),
+            ['huey_id', 'id', 'updated_at'])
+
+        foreign_keys = self.database.get_foreign_keys('session')
+        self.assertEqual(len(foreign_keys), 1)
+        foreign_key = foreign_keys[0]
+        self.assertEqual(foreign_key.column, 'huey_id')
+        self.assertEqual(foreign_key.dest_column, 'id')
+        self.assertEqual(foreign_key.dest_table, 'users')

 class SqliteMigrationTestCase(BaseMigrationTestCase, PeeweeTestCase):
     database = sqlite_db
diff --git a/playhouse/tests/test_models.py b/playhouse/tests/test_models.py
index 80a7eac1d..c43097b9a 100644
--- a/playhouse/tests/test_models.py
+++ b/playhouse/tests/test_models.py
@@ -5,6 +5,7 @@

 from peewee import *
 from peewee import ModelOptions
+from peewee import sqlite3
 from playhouse.tests.base import compiler
 from playhouse.tests.base import database_initializer
 from playhouse.tests.base import ModelTestCase
@@ -18,6 +19,7 @@
 in_memory_db = database_initializer.get_in_memory_database()

+supports_tuples = sqlite3.sqlite_version_info >= (3, 15, 0)

 class GCModel(Model):
     name = CharField(unique=True)
@@ -2308,3 +2310,20 @@ class Meta:

         m2 = qm2.get()
         self.assertEqual(m2.ids, '1,2,3')
+
+
+@skip_unless(
+    lambda: (isinstance(test_db, PostgresqlDatabase) or
+             (isinstance(test_db, SqliteDatabase) and supports_tuples)))
+class TestTupleComparison(ModelTestCase):
+    requires = [User]
+
+    def test_tuples(self):
+        ua = User.create(username='user-a')
+        ub = User.create(username='user-b')
+        uc = User.create(username='user-c')
+        query = User.select().where(
+            Tuple(User.username, User.id) == ('user-b', ub.id))
+        self.assertEqual(query.count(), 1)
+        obj = query.get()
+        self.assertEqual(obj, ub)
diff --git a/playhouse/tests/test_pool.py b/playhouse/tests/test_pool.py
index 459e7fd96..faa62ceaa 100644
--- a/playhouse/tests/test_pool.py
+++ b/playhouse/tests/test_pool.py
@@ -108,8 +108,10 @@ def open_conn():

     def test_max_conns(self):
         for i in range(self.db.max_connections):
+            self.db._local.closed = True
             self.db.connect()
             self.assertEqual(self.db.get_conn(), i + 1)
+        self.db._local.closed = True
         self.assertRaises(ValueError, self.db.connect)

     def test_stale_timeout(self):
@@ -215,6 +217,7 @@ def test_connect_cascade(self):
         self.assertEqual(db._connections, [(now, 4)])

         # Since conn 4 is closed, we will open a new conn.
+        db._local.closed = True  # Pretend we're in a different thread.
         db.connect()
         self.assertEqual(db.get_conn(), 5)
         self.assertEqual(sorted(db._in_use.keys()), [3, 5])
@@ -414,3 +417,15 @@ def test_execution_context(self):
                    for number in Number.select().order_by(Number.value)]

         self.assertEqual(numbers, [1, 3, 4, 5])
+
+    def test_bad_connection(self):
+        pooled_db.connect()
+        try:
+            pooled_db.execute_sql('select 1/0')
+        except Exception as exc:
+            pass
+        pooled_db.close()
+
+        pooled_db.connect()  # Re-connect.
+        pooled_db.execute_sql('select 1')  # Can execute queries.
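+        # The failed query left the connection in an error state; check-in
+        # went through _can_reuse(), which reset it, so the pooled
+        # connection is usable again after re-connecting.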
+        pooled_db.close()
diff --git a/playhouse/tests/test_pysqlite_ext.py b/playhouse/tests/test_pysqlite_ext.py
new file mode 100644
index 000000000..22ef451ed
--- /dev/null
+++ b/playhouse/tests/test_pysqlite_ext.py
@@ -0,0 +1,123 @@
+from peewee import *
+from playhouse.pysqlite_ext import Database
+from playhouse.tests.base import ModelTestCase
+
+
+db = Database(':memory:')
+
+class User(Model):
+    username = CharField()
+
+    class Meta:
+        database = db
+
+
+class TestPysqliteDatabase(ModelTestCase):
+    requires = [
+        User,
+    ]
+
+    def tearDown(self):
+        super(TestPysqliteDatabase, self).tearDown()
+        db.on_commit(None)
+        db.on_rollback(None)
+        db.on_update(None)
+
+    def test_commit_hook(self):
+        state = {}
+
+        @db.on_commit
+        def on_commit():
+            state.setdefault('commits', 0)
+            state['commits'] += 1
+
+        user = User.create(username='u1')
+        self.assertEqual(state['commits'], 1)
+
+        user.username = 'u1-e'
+        user.save()
+        self.assertEqual(state['commits'], 2)
+
+        with db.atomic():
+            User.create(username='u2')
+            User.create(username='u3')
+            User.create(username='u4')
+            self.assertEqual(state['commits'], 2)
+
+        self.assertEqual(state['commits'], 3)
+
+        with db.atomic() as txn:
+            User.create(username='u5')
+            txn.rollback()
+
+        self.assertEqual(state['commits'], 3)
+        self.assertEqual(User.select().count(), 4)
+
+    def test_rollback_hook(self):
+        state = {}
+
+        @db.on_rollback
+        def on_rollback():
+            state.setdefault('rollbacks', 0)
+            state['rollbacks'] += 1
+
+        user = User.create(username='u1')
+        self.assertEqual(state, {'rollbacks': 1})
+
+        with db.atomic() as txn:
+            User.create(username='u2')
+            txn.rollback()
+            self.assertEqual(state['rollbacks'], 2)
+
+        self.assertEqual(state['rollbacks'], 2)
+
+    def test_update_hook(self):
+        state = []
+
+        @db.on_update
+        def on_update(query, db, table, rowid):
+            state.append((query, db, table, rowid))
+
+        u = User.create(username='u1')
+        u.username = 'u2'
+        u.save()
+
+        self.assertEqual(state, [
+            ('INSERT', 'main', 'user', 1),
+            ('UPDATE', 'main', 'user', 1),
+        ])
+
+        with db.atomic():
+            User.create(username='u3')
+            User.create(username='u4')
+            u.delete_instance()
+            self.assertEqual(state, [
+                ('INSERT', 'main', 'user', 1),
+                ('UPDATE', 'main', 'user', 1),
+                ('INSERT', 'main', 'user', 2),
+                ('INSERT', 'main', 'user', 3),
+                ('DELETE', 'main', 'user', 1),
+            ])
+
+        self.assertEqual(len(state), 5)
+
+    def test_udf(self):
+        @db.func()
+        def backwards(s):
+            return s[::-1]
+
+        @db.func()
+        def titled(s):
+            return s.title()
+
+        query = db.execute_sql('SELECT titled(backwards(?));', ('hello',))
+        result, = query.fetchone()
+        self.assertEqual(result, 'Olleh')
+
+    def test_properties(self):
+        mem_used, mem_high = db.memory_used
+        self.assertTrue(mem_high >= mem_used)
+        self.assertFalse(mem_high == 0)
+
+        conn = db.connection
+        self.assertTrue(conn.cache_used is not None)
diff --git a/playhouse/tests/test_queries.py b/playhouse/tests/test_queries.py
index 6b5ea61b4..e03bdfe63 100644
--- a/playhouse/tests/test_queries.py
+++ b/playhouse/tests/test_queries.py
@@ -518,6 +518,10 @@ def test_where_chaining_collapsing(self):
         sq = SelectQuery(User).where(~(User.id == 1)).where(User.id == 2).where(~(User.id == 3))
         self.assertWhere(sq, '((NOT ("users"."id" = ?) AND ("users"."id" = ?)) AND NOT ("users"."id" = ?))', [1, 2, 3])
AND ("users"."id" = ?)) AND NOT ("users"."id" = ?))', [1, 2, 3]) + def test_tuples(self): + sq = User.select().where(Tuple(User.id, User.username) == (1, 'hello')) + self.assertWhere(sq, '(("users"."id", "users"."username") = (?, ?))', [1, 'hello']) + def test_grouping(self): sq = SelectQuery(User).group_by(User.id) self.assertGroupBy(sq, '"users"."id"', []) @@ -1667,6 +1671,35 @@ def setUp(self): for int_v, float_v in self.data: NullModel.create(int_field=int_v, float_field=float_v) + def test_frame(self): + query = (NullModel + .select( + NullModel.float_field, + fn.AVG(NullModel.float_field).over( + partition_by=[NullModel.int_field], + start=Window.preceding(), + end=Window.following(2)))) + sql, params = query.sql() + self.assertEqual(sql, ( + 'SELECT "t1"."float_field", AVG("t1"."float_field") ' + 'OVER (PARTITION BY "t1"."int_field" RANGE BETWEEN ' + 'UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM "nullmodel" AS t1')) + self.assertEqual(params, []) + + query = (NullModel + .select( + NullModel.float_field, + fn.AVG(NullModel.float_field).over( + partition_by=[NullModel.int_field], + start=SQL('CURRENT ROW'), + end=Window.following()))) + sql, params = query.sql() + self.assertEqual(sql, ( + 'SELECT "t1"."float_field", AVG("t1"."float_field") ' + 'OVER (PARTITION BY "t1"."int_field" RANGE BETWEEN ' + 'CURRENT ROW AND UNBOUNDED FOLLOWING) FROM "nullmodel" AS t1')) + self.assertEqual(params, []) + def test_partition_unordered(self): query = (NullModel .select( diff --git a/playhouse/tests/test_shortcuts.py b/playhouse/tests/test_shortcuts.py index 800a8ba44..cad26d9ef 100644 --- a/playhouse/tests/test_shortcuts.py +++ b/playhouse/tests/test_shortcuts.py @@ -341,7 +341,7 @@ def test_recursive_fk(self): self.assertEqual(model_to_dict(root), { 'id': root.id, 'name': root.name, - 'parent': {}, + 'parent': None, }) with assert_query_count(0): @@ -356,7 +356,7 @@ def test_recursive_fk(self): 'children': [{'id': child.id, 'name': child.name}], 'id': root.id, 'name': root.name, - 'parent': {}, + 'parent': None, }) with assert_query_count(1): diff --git a/playhouse/tests/test_sqlite_ext.py b/playhouse/tests/test_sqlite_ext.py index 27de2c376..8e5803623 100644 --- a/playhouse/tests/test_sqlite_ext.py +++ b/playhouse/tests/test_sqlite_ext.py @@ -724,11 +724,11 @@ def test_function_decorator(self): self.assertEqual([x[0] for x in pq.tuples()], [ 'testing', 'chatting', ' foo']) - def test_granular_transaction(self): + def test_lock_type_transaction(self): conn = ext_db.get_conn() def test_locked_dbw(isolation_level): - with ext_db.granular_transaction(isolation_level): + with ext_db.transaction(isolation_level): Post.create(message='p1') # Will not be saved. 
                conn2 = ext_db._connect(ext_db.database, **ext_db.connect_kwargs)
                conn2.execute('insert into post (message) values (?);', ('x1',))
@@ -737,7 +737,7 @@ def test_locked_dbw(isolation_level):
         self.assertRaises(sqlite3.OperationalError, test_locked_dbw, 'deferred')

         def test_locked_dbr(isolation_level):
-            with ext_db.granular_transaction(isolation_level):
+            with ext_db.transaction(isolation_level):
                 Post.create(message='p2')
                 other_db = database_initializer.get_database(
                     'sqlite',
@@ -770,7 +770,7 @@ def test_locked_dbr(isolation_level):
         conn.rollback()
         ext_db.set_autocommit(True)

-        with ext_db.granular_transaction('deferred'):
+        with ext_db.transaction('deferred'):
             Post.create(message='p4')

         res = conn2.execute('select message from post order by message;')
@@ -1200,3 +1200,20 @@ def test_tree_changes(self):
             ('westerns', 3),
             ('hard scifi', 4),
         ])
+
+    def test_id_not_overwritten(self):
+        class Node(BaseExtModel):
+            parent = ForeignKeyField('self', null=True)
+            name = CharField()
+
+        NodeClosure = ClosureTable(Node)
+        ext_db.create_tables([Node, NodeClosure], True)
+
+        root = Node.create(name='root')
+        c1 = Node.create(name='c1', parent=root)
+        c2 = Node.create(name='c2', parent=root)
+
+        query = NodeClosure.descendants(root)
+        self.assertEqual(sorted([(n.id, n.name) for n in query]),
+                         [(c1.id, 'c1'), (c2.id, 'c2')])
+        ext_db.drop_tables([Node, NodeClosure])
diff --git a/playhouse/tests/test_sqliteq.py b/playhouse/tests/test_sqliteq.py
index 168aac8d6..ce102508e 100644
--- a/playhouse/tests/test_sqliteq.py
+++ b/playhouse/tests/test_sqliteq.py
@@ -7,12 +7,14 @@

 try:
     import gevent
+    from gevent.event import Event as GreenEvent
 except ImportError:
     gevent = None

 from peewee import *
 from playhouse.sqliteq import ResultTimeout
 from playhouse.sqliteq import SqliteQueueDatabase
+from playhouse.sqliteq import WriterPaused
 from playhouse.tests.base import database_initializer
 from playhouse.tests.base import PeeweeTestCase
 from playhouse.tests.base import skip_if
@@ -45,7 +47,7 @@ def setUp(self):
         self.db = get_db(**self.database_config)

         # Sanity check at startup.
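+        # (The single-writer rewrite removed the read queue, so queue_size()
+        # now returns one integer instead of a (write, read) pair.)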
-        self.assertEqual(self.db.queue_size(), (0, 0))
+        self.assertEqual(self.db.queue_size(), 0)

     def tearDown(self):
         super(BaseTestQueueDatabase, self).tearDown()
@@ -60,9 +62,15 @@ def tearDown(self):
             if os.path.exists(filename):
                 os.unlink(filename)

+    def test_query_error(self):
+        self.db.start()
+        curs = self.db.execute_sql('foo bar baz')
+        self.assertRaises(OperationalError, curs.fetchone)
+        self.db.stop()
+
     def test_query_execution(self):
         qr = User.select().execute()
-        self.assertEqual(self.db.queue_size(), (0, 1))
+        self.assertEqual(self.db.queue_size(), 0)

         self.db.start()

@@ -72,13 +80,16 @@ def test_query_execution(self):
         self.assertTrue(huey.id is not None)
         self.assertTrue(mickey.id is not None)

-        self.assertEqual(self.db.queue_size(), (0, 0))
+        self.assertEqual(self.db.queue_size(), 0)
         self.db.stop()

     def create_thread(self, fn, *args):
         raise NotImplementedError

+    def create_event(self):
+        raise NotImplementedError
+
     def test_multiple_threads(self):
         def create_rows(idx, nrows):
             for i in range(idx, idx + nrows):
@@ -94,6 +105,60 @@ def create_rows(idx, nrows):
         self.assertEqual(User.select().count(), total)
         self.db.stop()

+    def test_pause(self):
+        event_a = self.create_event()
+        event_b = self.create_event()
+
+        def create_user(name, event, expect_paused):
+            event.wait()
+            if expect_paused:
+                self.assertRaises(WriterPaused, lambda: User.create(name=name))
+            else:
+                User.create(name=name)
+
+        self.db.start()
+
+        t_a = self.create_thread(create_user, 'a', event_a, True)
+        t_a.start()
+        t_b = self.create_thread(create_user, 'b', event_b, False)
+        t_b.start()
+
+        User.create(name='c')
+        self.assertEqual(User.select().count(), 1)
+
+        # Pause operations but preserve the writer thread/connection.
+        self.db.pause()
+
+        event_a.set()
+        self.assertEqual(User.select().count(), 1)
+        t_a.join()
+
+        self.db.unpause()
+        self.assertEqual(User.select().count(), 1)
+
+        event_b.set()
+        t_b.join()
+        self.assertEqual(User.select().count(), 2)
+
+        self.db.stop()
+
+    def test_restart(self):
+        self.db.start()
+        User.create(name='a')
+        self.db.stop()
+        self.db._results_timeout = 0.0001
+
+        self.assertRaises(ResultTimeout, User.create, name='b')
+        self.assertEqual(User.select().count(), 1)
+
+        self.db.start()  # Will execute the pending "b" INSERT.
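+        # The 'b' INSERT queued while the writer was stopped is executed by
+        # the fresh writer thread, which is why the final count is 3.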
+        self.db._results_timeout = None
+
+        User.create(name='c')
+        self.assertEqual(User.select().count(), 3)
+        self.assertEqual(sorted(u.name for u in User.select()),
+                         ['a', 'b', 'c'])
+
     def test_waiting(self):
         D = {}

@@ -117,6 +182,22 @@ def get_users():

         self.assertEqual(sorted(D), ['charlie', 'huey', 'users', 'zaizee'])

+    def test_next_method(self):
+        self.db.start()
+
+        User.create(name='mickey')
+        User.create(name='huey')
+        query = iter(User.select().order_by(User.name))
+        self.assertEqual(next(query).name, 'huey')
+        self.assertEqual(next(query).name, 'mickey')
+        self.assertRaises(StopIteration, lambda: next(query))
+
+        self.assertEqual(
+            next(self.db.execute_sql('PRAGMA journal_mode'))[0],
+            'wal')
+
+        self.db.stop()
+

 class TestThreadedDatabaseThreads(BaseTestQueueDatabase, PeeweeTestCase):
     database_config = {'use_gevent': False}
@@ -130,6 +211,9 @@ def create_thread(self, fn, *args):
         t.daemon = True
         return t

+    def create_event(self):
+        return threading.Event()
+
     def test_timeout(self):
         @self.db.func()
         def slow(n):
@@ -156,6 +240,9 @@

 class TestThreadedDatabaseGreenlets(BaseTestQueueDatabase, PeeweeTestCase):
     def create_thread(self, fn, *args):
         return gevent.Greenlet(fn, *args)

+    def create_event(self):
+        return GreenEvent()
+

 if __name__ == '__main__':
     unittest.main(argv=sys.argv)
diff --git a/playhouse/tests/test_transactions.py b/playhouse/tests/test_transactions.py
index 95513e3fe..a4a86dc1e 100644
--- a/playhouse/tests/test_transactions.py
+++ b/playhouse/tests/test_transactions.py
@@ -59,25 +59,26 @@ def test_atomic_nesting(self):
         rollback = db_mocks['rollback']

         with _atomic(patched_db):
-            patched_db.transaction.assert_called_once_with()
-            begin.assert_called_once_with()
+            patched_db.transaction.assert_called_once_with(None)
+            begin.assert_called_once_with(lock_type=None)
             self.assertEqual(patched_db.savepoint.call_count, 0)

             with _atomic(patched_db):
-                patched_db.transaction.assert_called_once_with()
-                begin.assert_called_once_with()
+                patched_db.transaction.assert_called_once_with(None)
+                begin.assert_called_once_with(lock_type=None)
                 patched_db.savepoint.assert_called_once_with()
                 self.assertEqual(commit.call_count, 0)
                 self.assertEqual(rollback.call_count, 0)

                 with _atomic(patched_db):
-                    patched_db.transaction.assert_called_once_with()
-                    begin.assert_called_once_with()
+                    (patched_db.transaction
+                     .assert_called_once_with(None))
+                    begin.assert_called_once_with(lock_type=None)
                     self.assertEqual(
                         patched_db.savepoint.call_count, 2)

-            begin.assert_called_once_with()
+            begin.assert_called_once_with(lock_type=None)
             self.assertEqual(commit.call_count, 0)
             self.assertEqual(rollback.call_count, 0)
diff --git a/setup.py b/setup.py
index 0e5ead70e..4730853a4 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
 import os
+import platform
 import sys
 import warnings
 from distutils.core import setup
@@ -11,6 +12,7 @@

 setup_kwargs = {}
 cython_min_version = '0.22.1'
+
 try:
     from Cython.Distutils import build_ext
     from Cython import __version__ as cython_version
@@ -20,7 +22,11 @@
         'Cython does not seem to be installed. To enable Cython C '
         'extensions, install Cython >=' + cython_min_version + '.')
 else:
-    if StrictVersion(cython_version) < StrictVersion(cython_min_version):
+    if platform.python_implementation() != 'CPython':
+        cython_installed = False
+        warnings.warn('Cython C extensions disabled as you are not using '
+                      'CPython.')
+    elif StrictVersion(cython_version) < StrictVersion(cython_min_version):
         cython_installed = False
         warnings.warn('Cython C extensions for peewee will NOT be built, '
                       'because the installed Cython version '
@@ -71,6 +77,6 @@
         'Programming Language :: Python',
         'Programming Language :: Python :: 3',
     ],
-    scripts = ['pwiz.py'],
+    scripts = ['pwiz.py', 'playhouse/pskel'],
     **setup_kwargs
 )
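The pool changes above also let callers bound how long ``connect()`` blocks
once ``max_connections`` is reached. A minimal sketch of the new ``timeout``
parameter (the database name and limits here are illustrative):

.. code-block:: python

    from playhouse.pool import MaxConnectionsExceeded
    from playhouse.pool import PooledPostgresqlDatabase

    # Retry roughly every 0.1s for up to 30 seconds before giving up.
    # Note that timeout=0 is coerced to float('inf'), i.e. wait forever.
    db = PooledPostgresqlDatabase('app_db', max_connections=4, timeout=30)

    try:
        db.connect()
    except MaxConnectionsExceeded:
        # Every connection stayed checked out for the full 30 seconds.
        pass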