Skip to content

Commit

Permalink
Draw column relations (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
maurerle authored Sep 16, 2024
2 parents b9b62f3 + 8fa193e commit 74d7bee
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 65 deletions.
8 changes: 4 additions & 4 deletions eralchemy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,10 @@ def check_table(name):
_relationships = [
r
for r in _relationships
if not exclude_tables_re.fullmatch(r.right_col)
and not exclude_tables_re.fullmatch(r.left_col)
and include_tables_re.fullmatch(r.right_col)
and include_tables_re.fullmatch(r.left_col)
if not exclude_tables_re.fullmatch(r.right_table)
and not exclude_tables_re.fullmatch(r.left_table)
and include_tables_re.fullmatch(r.right_table)
and include_tables_re.fullmatch(r.left_table)
]

def check_column(name):
Expand Down
69 changes: 44 additions & 25 deletions eralchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ def to_mermaid_er(self) -> str:

def to_dot(self) -> str:
base = ROW_TAGS.format(
' ALIGN="LEFT"',
' ALIGN="LEFT" {port}',
"{key_opening}{col_name}{key_closing} {type}{null}",
)
return base.format(
port=f'PORT="{self.name}"' if self.name else "",
key_opening="<u>" if self.is_key else "",
key_closing="</u>" if self.is_key else "",
col_name=FONT_TAGS.format(self.name),
Expand All @@ -129,7 +130,19 @@ class Relation(Drawable):
"""Represents a Relation in the intermediaty syntax."""

RE = re.compile(
r"(?P<left_name>\S+(\s*\S+)?)\s+(?P<left_cardinality>[*?+1])--(?P<right_cardinality>[*?+1])\s*(?P<right_name>\S+(\s*\S+)?)",
r"""
(?P<left_table>[^\s]+?)
(?:\.\"(?P<left_column>.+)\")?
\s*
(?P<left_cardinality>[*?+1])
--
(?P<right_cardinality>[*?+1])
\s*
(?P<right_table>[^\s]+?)
(?:\.\"(?P<right_column>.+)\")?
\s*$
""",
re.VERBOSE,
)
cardinalities = {"*": "0..N", "?": "{0,1}", "+": "1..N", "1": "1", "": None}
cardinalities_mermaid = {
Expand All @@ -145,57 +158,61 @@ class Relation(Drawable):

@staticmethod
def make_from_match(match: re.Match) -> Relation:
return Relation(
right_col=match.group("right_name"),
left_col=match.group("left_name"),
right_cardinality=match.group("right_cardinality"),
left_cardinality=match.group("left_cardinality"),
)
return Relation(**match.groupdict())

def __init__(
self,
right_col,
left_col,
right_table,
left_table,
right_cardinality=None,
left_cardinality=None,
right_column=None,
left_column=None,
):
if (
right_cardinality not in self.cardinalities.keys()
or left_cardinality not in self.cardinalities.keys()
):
raise ValueError(f"Cardinality should be in {self.cardinalities.keys()}")
self.right_col = right_col
self.left_col = left_col
self.right_table = right_table
self.right_column = right_column or ""
self.left_table = left_table
self.left_column = left_column or ""
self.right_cardinality = right_cardinality
self.left_cardinality = left_cardinality

def to_markdown(self) -> str:
return f"{self.left_col} {self.left_cardinality}--{self.right_cardinality} {self.right_col}"
return "{}{} {}--{} {}{}".format(
self.left_table,
"" if not self.left_column else f'."{self.left_column}"',
self.left_cardinality,
self.right_cardinality,
self.right_table,
"" if not self.right_column else f'."{self.right_column}"',
)

def to_mermaid(self) -> str:
normalized = (
Relation.cardinalities_mermaid.get(k, k)
for k in (
sanitize_mermaid(self.left_col),
sanitize_mermaid(self.left_table),
self.left_cardinality,
self.right_cardinality,
sanitize_mermaid(self.right_col),
sanitize_mermaid(self.right_table),
)
)
return '{} "{}" -- "{}" {}'.format(*normalized)

def to_mermaid_er(self) -> str:
left = Relation.cardinalities_crowfoot.get(
self.left_cardinality,
self.left_cardinality,
)
right = Relation.cardinalities_crowfoot.get(
self.right_cardinality,
self.right_cardinality,
)

left_col = sanitize_mermaid(self.left_col, is_er=True)
right_col = sanitize_mermaid(self.right_col, is_er=True)
left_col = sanitize_mermaid(self.left_table, is_er=True)
right_col = sanitize_mermaid(self.right_table, is_er=True)
return f"{left_col} {left}--{right} {right_col} : has"

def graphviz_cardinalities(self, card) -> str:
Expand All @@ -211,10 +228,10 @@ def to_dot(self) -> str:
cards.append("tail" + self.graphviz_cardinalities(self.left_cardinality))
if self.right_cardinality != "":
cards.append("head" + self.graphviz_cardinalities(self.right_cardinality))
return '"{}" -- "{}" [{}];'.format(
self.left_col,
self.right_col,
",".join(cards),
left_col = f':"{self.left_column}"' if self.left_column else ""
right_col = f':"{self.right_column}"' if self.right_column else ""
return (
f'"{self.left_table}"{left_col} -- "{self.right_table}"{right_col} [{",".join(cards)}];'
)

def __eq__(self, other: object) -> bool:
Expand All @@ -223,8 +240,10 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, Relation):
return False
other_inversed = Relation(
right_col=other.left_col,
left_col=other.right_col,
right_table=other.left_table,
right_column=other.left_column,
left_table=other.right_table,
left_column=other.right_column,
right_cardinality=other.left_cardinality,
left_cardinality=other.right_cardinality,
)
Expand Down
4 changes: 2 additions & 2 deletions eralchemy/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def update_models(

if isinstance(new_obj, Relation):
tables_names = [t.name for t in tables]
_check_colname_in_lst(new_obj.right_col, tables_names)
_check_colname_in_lst(new_obj.left_col, tables_names)
_check_colname_in_lst(new_obj.right_table, tables_names)
_check_colname_in_lst(new_obj.left_table, tables_names)
return current_table, tables, relations + [new_obj]

if isinstance(new_obj, Column):
Expand Down
6 changes: 4 additions & 2 deletions eralchemy/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def relation_to_intermediary(fk: sa.ForeignKey) -> Relation:
# if this is the case, we are not optional and must be unique
right_cardinality = "1" if check_all_compound_same_parent(fk) else "*"
return Relation(
right_col=format_name(fk.parent.table.fullname),
left_col=format_name(fk.column.table.fullname),
right_table=format_name(fk.parent.table.fullname),
right_column=format_name(fk.parent.name),
left_table=format_name(fk.column.table.fullname),
left_column=format_name(fk.column.name),
right_cardinality=right_cardinality,
left_cardinality="?" if fk.parent.nullable else "1",
)
Expand Down
18 changes: 11 additions & 7 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ class ExcludeWithSchema(Base):
)

relation = Relation(
right_col="parent",
left_col="child",
right_cardinality="?",
left_table="child",
left_column="parent_id",
right_table="parent",
right_column="id",
left_cardinality="*",
right_cardinality="?",
)

exclude_id = ERColumn(name="id", type="INTEGER", is_key=True)
Expand All @@ -82,8 +84,10 @@ class ExcludeWithSchema(Base):
)

exclude_relation = Relation(
right_col="parent",
left_col="exclude",
right_table="parent",
right_column="id",
left_table="exclude",
left_column="parent_id",
right_cardinality="?",
left_cardinality="*",
)
Expand Down Expand Up @@ -117,8 +121,8 @@ class ExcludeWithSchema(Base):
[exclude]
*id {label:"INTEGER"}
parent_id {label:"INTEGER"}
parent ?--* child
parent ?--* exclude
child."parent_id" *--? parent."id"
exclude."parent_id" *--? parent."id"
"""


Expand Down
27 changes: 12 additions & 15 deletions tests/test_intermediary_to_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest
from pygraphviz import AGraph

from eralchemy.cst import DOT_GRAPH_BEGINNING
from eralchemy.main import _intermediary_to_dot
from tests.common import (
child,
Expand All @@ -17,17 +16,13 @@
relation,
)

GRAPH_LAYOUT = DOT_GRAPH_BEGINNING + "%s }"
column_re = re.compile(
'\\<TR\\>\\<TD\\ ALIGN\\=\\"LEFT\\"\\>(.*)\\<\\/TD\\>\\<\\/TR\\>',
)
column_re = re.compile(r"\<TR\>\<TD\ ALIGN\=\"LEFT\"\ PORT\=\".+\">(.*)\<\/TD\>\<\/TR\>")
header_re = re.compile(
'\\<TR\\>\\<TD\\>\\<B\\>\\<FONT\\ POINT\\-SIZE\\=\\"16\\"\\>(.*)'
"\\<\\/FONT\\>\\<\\/B\\>\\<\\/TD\\>\\<\\/TR\\>",
r"\<TR\>\<TD\>\<B\>\<FONT\ POINT\-SIZE\=\"16\"\>(.*)" r"\<\/FONT\>\<\/B\>\<\/TD\>\<\/TR\>"
)
column_inside = re.compile(
"(?P<key_opening>.*)\\<FONT\\>(?P<name>.*)\\<\\/FONT\\>"
"(?P<key_closing>.*)\\ <FONT\\>\\ \\[(?P<type>.*)\\]\\<\\/FONT\\>",
r"(?P<key_opening>.*)\<FONT\>(?P<name>.*)\<\/FONT\>"
r"(?P<key_closing>.*)\<FONT\>\ \[(?P<type>.*)\]\<\/FONT\>"
)


Expand Down Expand Up @@ -79,7 +74,7 @@ def assert_column_well_rendered_to_dot(col):
col_parsed = column_inside.match(col_no_table[0])
assert col_parsed.group("key_opening") == ("<u>" if col.is_key else "")
assert col_parsed.group("name") == col.name
assert col_parsed.group("key_closing") == ("</u>" if col.is_key else "")
assert col_parsed.group("key_closing") == ("</u> " if col.is_key else " ")
assert col_parsed.group("type") == col.type


Expand All @@ -92,14 +87,16 @@ def test_column_is_dot_format():

def test_relation():
relation_re = re.compile(
'\\"(?P<l_name>.+)\\"\\ \\-\\-\\ \\"(?P<r_name>.+)\\"\\ '
"\\[taillabel\\=\\<\\<FONT\\>(?P<l_card>.+)\\<\\/FONT\\>\\>"
"\\,headlabel\\=\\<\\<FONT\\>(?P<r_card>.+)\\<\\/FONT\\>\\>\\]\\;",
r"\"(?P<l_table>.+)\":\"(?P<l_column>.+)\"\ \-\-\ \"(?P<r_table>.+)\":\"(?P<r_column>.+)\"\ "
r"\[taillabel\=\<\<FONT\>(?P<l_card>.+)\<\/FONT\>\>"
r"\,headlabel\=\<\<FONT\>(?P<r_card>.+)\<\/FONT\>\>\]\;"
)
dot = relation.to_dot()
r = relation_re.match(dot)
assert r.group("l_name") == "child"
assert r.group("r_name") == "parent"
assert r.group("l_table") == "child"
assert r.group("l_column") == "parent_id"
assert r.group("r_table") == "parent"
assert r.group("r_column") == "id"
assert r.group("l_card") == "0..N"
assert r.group("r_card") == "{0,1}"

Expand Down
5 changes: 4 additions & 1 deletion tests/test_intermediary_to_er.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def test_column_to_er():


def test_relation():
assert relation.to_markdown() in ["parent ?--* child", "child *--? parent"]
assert relation.to_markdown() in [
'parent."id" *--? child."parent_id"',
'child."parent_id" *--? parent."id"',
]


def assert_table_well_rendered_to_er(table):
Expand Down
8 changes: 5 additions & 3 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ def test_parse_line():
assert isinstance(rv, Column)

for s in relations_lst:
rv = parse_line(s)
assert rv.right_col == s[16:].strip()
assert rv.left_col == s[:12].strip()
rv = parse_line(s) # type: Relation
assert rv.right_table == s[16:].strip()
assert rv.right_column == ""
assert rv.left_table == s[:12].strip()
assert rv.left_column == ""
assert rv.right_cardinality == s[15]
assert rv.left_cardinality == s[12]
assert isinstance(rv, Relation)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_sqla_to_intermediary.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ def test_table_names_in_relationships(pg_db_uri):
table_names = [t.name for t in tables]

# Assert column names are table names
assert all(r.right_col in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships)
assert all(r.right_table in table_names for r in relationships)
assert all(r.left_table in table_names for r in relationships)

# Assert column names match table names
for r in relationships:
r_name = table_names[table_names.index(r.right_col)]
l_name = table_names[table_names.index(r.left_col)]
r_name = table_names[table_names.index(r.right_table)]
l_name = table_names[table_names.index(r.left_table)]

# Table name in relationship should *NOT* have a schema
assert r_name.find(".") == -1
Expand All @@ -142,8 +142,8 @@ def test_table_names_in_relationships_with_schema(pg_db_uri):
table_names = [t.name for t in tables]

# Assert column names match table names, including schema
assert all(r.right_col in table_names for r in relationships)
assert all(r.left_col in table_names for r in relationships)
assert all(r.right_table in table_names for r in relationships)
assert all(r.left_table in table_names for r in relationships)

# Assert column names match table names, including schema
for r in relationships:
Expand Down

0 comments on commit 74d7bee

Please sign in to comment.