Skip to content

Commit

Permalink
Fix column->field name translation. Fixes #1437.
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles Leifer committed Jan 30, 2018
1 parent bd3ee2a commit 20c2a9f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
20 changes: 16 additions & 4 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -5761,6 +5761,7 @@ def _initialize_columns(self):
self.columns = []
self.converters = converters = [None] * self.ncols
self.fields = fields = [None] * self.ncols
self.translation = {}

for idx, description_item in enumerate(description):
column = description_item[0]
Expand Down Expand Up @@ -5814,12 +5815,14 @@ class ModelDictCursorWrapper(BaseModelCursorWrapper):
def process_row(self, row):
result = {}
columns, converters = self.columns, self.converters
fields = self.fields

for i in range(self.ncols):
attr = fields[i].name if fields[i] is not None else columns[i]
if converters[i] is not None:
result[columns[i]] = converters[i](row[i])
result[attr] = converters[i](row[i])
else:
result[columns[i]] = row[i]
result[attr] = row[i]

return result

Expand All @@ -5837,7 +5840,13 @@ def process_row(self, row):
class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper):
def initialize(self):
self._initialize_columns()
self.tuple_class = namedtuple('Row', self.columns)
attributes = []
for i in range(self.ncols):
if self.fields[i] is not None:
attributes.append(self.fields[i].name)
else:
attributes.append(self.columns[i])
self.tuple_class = namedtuple('Row', attributes)
self.constructor = lambda row: self.tuple_class(*row)


Expand Down Expand Up @@ -5915,7 +5924,10 @@ def process_row(self, row):
set_keys = set()
for idx, key in enumerate(self.column_keys):
instance = objects[key]
column = self.columns[idx]
if self.fields[idx] is not None:
column = self.fields[idx].name
else:
column = self.columns[idx]
value = row[idx]
if value is not None:
set_keys.add(key)
Expand Down
23 changes: 23 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class Color(TestModel):
is_neutral = BooleanField(default=False)


class Post(TestModel):
content = TextField(column_name='Content')
timestamp = DateTimeField(column_name='TimeStamp',
default=datetime.datetime.now)


class TestModelAPIs(ModelTestCase):
def add_user(self, username):
return User.create(username=username)
Expand All @@ -51,6 +57,23 @@ def do_test(n):
do_test(4)
self.assertRaises(AssertionError, do_test, 5)

@requires_models(Post)
def test_column_field_translation(self):
ts = datetime.datetime(2017, 2, 1, 13, 37)
ts2 = datetime.datetime(2017, 2, 2, 13, 37)
p = Post.create(content='p1', timestamp=ts)
p2 = Post.create(content='p2', timestamp=ts2)

p_db = Post.get(Post.content == 'p1')
self.assertEqual(p_db.content, 'p1')
self.assertEqual(p_db.timestamp, ts)

pd1, pd2 = Post.select().order_by(Post.id).dicts()
self.assertEqual(pd1['content'], 'p1')
self.assertEqual(pd1['timestamp'], ts)
self.assertEqual(pd2['content'], 'p2')
self.assertEqual(pd2['timestamp'], ts2)

@requires_models(User, Tweet)
def test_create(self):
with self.assertQueryCount(1):
Expand Down

0 comments on commit 20c2a9f

Please sign in to comment.