Skip to content

Commit

Permalink
Fixed #36088 -- Avoided unnecessary DEFAULT usage on bulk_create().
Browse files Browse the repository at this point in the history
When all values of a field with a db_default are DatabaseDefault, which is the
case most of the time, there is no point in specifying explicit DEFAULT for all
INSERT VALUES as that's what the database will do anyway if not specified.

In the case of PostgreSQL, doing so can even be harmful as it prevents the usage
of the UNNEST strategy, and in the case of Oracle, which doesn't support the
usage of the DEFAULT keyword, it unnecessarily requires providing literal
db defaults.
  • Loading branch information
charettes committed Jan 12, 2025
1 parent 8bee7fa commit 4e4f400
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 21 deletions.
11 changes: 0 additions & 11 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,21 +671,10 @@ async def acreate(self, **kwargs):
acreate.alters_data = True

# Prepare every object in ``objs`` before a bulk insert: give it a primary
# key value when none is set, replace DatabaseDefault placeholders on
# backends without DEFAULT support in bulk inserts, and prepare its related
# fields for saving.
# NOTE(review): the hunk header for this file reports 0 additions and 11
# deletions, so part of the body shown here is code this commit removes --
# confirm against the raw diff which lines remain.
def _prepare_for_bulk_create(self, objs):
# Imported inside the function body rather than at module level.
from django.db.models.expressions import DatabaseDefault

# Look up the connection registered under this queryset's ``self.db``.
connection = connections[self.db]
for obj in objs:
if not obj._is_pk_set():
# Populate new PK values.
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
# Backends that can't use the DEFAULT keyword in a bulk insert get the
# field's ``db_default`` set on the object in place of the placeholder.
if not connection.features.supports_default_keyword_in_bulk_insert:
for field in obj._meta.fields:
# Generated fields are skipped.
if field.generated:
continue
value = getattr(obj, field.attname)
# Only DatabaseDefault placeholders are replaced; any other value
# is left untouched.
if isinstance(value, DatabaseDefault):
setattr(obj, field.attname, field.db_default)

obj._prepare_related_fields_for_save(operation_name="bulk_create")

def _check_bulk_create_options(
Expand Down
49 changes: 40 additions & 9 deletions django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,23 +1811,54 @@ def as_sql(self):
on_conflict=self.query.on_conflict,
)
result = ["%s %s" % (insert_statement, qn(opts.db_table))]
fields = self.query.fields or [opts.pk]
result.append("(%s)" % ", ".join(qn(f.column) for f in fields))

if self.query.fields:
value_rows = [
[
self.prepare_value(field, self.pre_save_val(field, obj))
for field in fields
if fields := self.query.fields:
from django.db.models.expressions import DatabaseDefault

supports_default_keyword_in_bulk_insert = (
self.connection.features.supports_default_keyword_in_bulk_insert
)
value_cols = []
for field in list(fields):
field_prepare = partial(self.prepare_value, field)
field_pre_save = partial(self.pre_save_val, field)
field_values = [
field_prepare(field_pre_save(obj)) for obj in self.query.objs
]
for obj in self.query.objs
]
if field.has_db_default():
# If all values are DEFAULT don't include the field and its
# values in the query as they are redundant and could prevent
# optimizations. This cannot be done if we're dealing with the
# last field as INSERT statements require at least one.
if len(fields) > 1 and all(
isinstance(value, DatabaseDefault) for value in field_values
):
fields.remove(field)
continue
elif not supports_default_keyword_in_bulk_insert:
# If the field cannot be excluded from the INSERT for the
# reasons listed above and the backend doesn't support the
# DEFAULT keyword, each value must be expanded into its
# underlying expression.
prepared_db_default = field_prepare(field.db_default)
field_values = [
(
prepared_db_default
if isinstance(value, DatabaseDefault)
else value
)
for value in field_values
]
value_cols.append(field_values)
value_rows = list(zip(*value_cols))
result.append("(%s)" % ", ".join(qn(f.column) for f in fields))
else:
# An empty object.
value_rows = [
[self.connection.ops.pk_default_value()] for _ in self.query.objs
]
fields = [None]
result.append("(%s)" % qn(opts.pk.column))

# Currently the backends just accept values when generating bulk
# queries and generate their own placeholders. Doing that isn't
Expand Down
2 changes: 1 addition & 1 deletion tests/backends/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Test model pairing an integer ``root`` with a positive integer ``square``.
class Square(models.Model):
root = models.IntegerField()
# NOTE(review): two ``square`` declarations appear because this is a diff
# rendering (the hunk header reports 1 addition and 1 deletion); presumably
# the ``db_default=9`` line is the added one -- confirm against the raw diff.
square = models.PositiveIntegerField()
square = models.PositiveIntegerField(db_default=9)

# Render the instance as "<root> ** 2 == <square>".
def __str__(self):
return "%s ** 2 == %s" % (self.root, self.square)
Expand Down
6 changes: 6 additions & 0 deletions tests/backends/postgresql/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ def test_unnest_eligible(self):
[Square(root=2, square=4), Square(root=3, square=9)]
)
self.assertIn("UNNEST", ctx[0]["sql"])

# Bulk-creating Square objects that do not pass ``square`` must run as a
# single query whose SQL contains UNNEST, and the returned objects must
# report ``square == 9``.
def test_unnest_eligible_db_default(self):
with self.assertNumQueries(1) as ctx:
# Only ``root`` is supplied; ``square`` is left to its default.
squares = Square.objects.bulk_create([Square(root=3), Square(root=3)])
# The single captured query must use UNNEST.
self.assertIn("UNNEST", ctx[0]["sql"])
# Each returned instance carries the value 9 for ``square``.
self.assertEqual([square.square for square in squares], [9, 9])
6 changes: 6 additions & 0 deletions tests/bulk_create/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from decimal import Decimal

from django.db import models
from django.db.models.functions import Now
from django.utils import timezone

try:
Expand Down Expand Up @@ -141,3 +142,8 @@ class RelatedModel(models.Model):
name = models.CharField(max_length=15, null=True)
country = models.OneToOneField(Country, models.CASCADE, primary_key=True)
big_auto_fields = models.ManyToManyField(BigAutoFieldModel)


# Test model with a plain ``name`` column and a ``created_at`` column that
# declares a database-level default.
class DbDefaultModel(models.Model):
name = models.CharField(max_length=10)
# The default is the ``Now()`` database function rather than a Python value.
created_at = models.DateTimeField(db_default=Now())
26 changes: 26 additions & 0 deletions tests/bulk_create/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.utils import timezone

from .models import (
BigAutoFieldModel,
Country,
DbDefaultModel,
FieldsWithDbColumns,
NoFields,
NullableFields,
Expand Down Expand Up @@ -840,3 +842,27 @@ def test_update_conflicts_unique_fields_update_fields_db_column(self):
{"rank": 2, "name": "d"},
],
)

# Check how often the quoted ``created_at`` column name appears in the SQL
# issued by bulk_create(), first when no object overrides the db_default and
# then when one object does.
def test_db_default_field_excluded(self):
# created_at is excluded when no db_default override is provided.
with self.assertNumQueries(1) as ctx:
DbDefaultModel.objects.bulk_create(
[DbDefaultModel(name="foo"), DbDefaultModel(name="bar")]
)
# Quote the name the way the active backend does so the substring count
# matches the generated SQL.
created_at_quoted_name = connection.ops.quote_name("created_at")
# One occurrence is still expected on backends that can return rows from a
# bulk insert -- presumably from the returning clause; confirm.
self.assertEqual(
ctx[0]["sql"].count(created_at_quoted_name),
(1 if connection.features.can_return_rows_from_bulk_insert else 0),
)
# created_at is included when a db_default override is provided.
with self.assertNumQueries(1) as ctx:
DbDefaultModel.objects.bulk_create(
[
DbDefaultModel(name="foo", created_at=timezone.now()),
DbDefaultModel(name="bar"),
]
)
# The expected count is one higher than in the first half of the test,
# because the column now also appears in the INSERT column list.
self.assertEqual(
ctx[0]["sql"].count(created_at_quoted_name),
(2 if connection.features.can_return_rows_from_bulk_insert else 1),
)

0 comments on commit 4e4f400

Please sign in to comment.