Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Delete metadata.py #2253

Merged
merged 3 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor out custom columns
  • Loading branch information
mattzh72 committed Dec 14, 2024
commit 4884967ae8683a8758f0a72f823d753433088fbe
103 changes: 8 additions & 95 deletions letta/orm/agent.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import uuid
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional

from sqlalchemy import JSON, String, TypeDecorator, UniqueConstraint
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.block import Block
from letta.orm.custom_columns import (
EmbeddingConfigColumn,
LLMConfigColumn,
ToolRulesColumn,
)
from letta.orm.message import Message
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.tool_rule import (
ChildToolRule,
InitToolRule,
TerminalToolRule,
ToolRule,
)
from letta.schemas.tool_rule import ToolRule

if TYPE_CHECKING:
from letta.orm.agents_tags import AgentsTags
Expand All @@ -29,92 +28,6 @@
from letta.orm.tool import Tool


class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
# return vars(value)
if isinstance(value, LLMConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return LLMConfig(**value)
return value


class EmbeddingConfigColumn(TypeDecorator):
"""Custom type for storing EmbeddingConfig as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
# return vars(value)
if isinstance(value, EmbeddingConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return EmbeddingConfig(**value)
return value


class ToolRulesColumn(TypeDecorator):
"""Custom type for storing a list of ToolRules as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
"""Convert a list of ToolRules to JSON-serializable format."""
if value:
data = [rule.model_dump() for rule in value]
for d in data:
d["type"] = d["type"].value

for d in data:
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
return data
return value

def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
"""Convert JSON back to a list of ToolRules."""
if value:
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
return InitToolRule(**data)
elif rule_type == ToolRuleType.exit_loop:
return TerminalToolRule(**data)
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")


class Agent(SqlalchemyBase, OrganizationMixin):
__tablename__ = "agents"
__pydantic_model__ = PydanticAgentState
Expand Down
152 changes: 152 additions & 0 deletions letta/orm/custom_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import base64
from typing import List, Union

import numpy as np
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
"""Custom type for storing EmbeddingConfig as JSON."""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value and isinstance(value, EmbeddingConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return EmbeddingConfig(**value)
return value


class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON."""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value and isinstance(value, LLMConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return LLMConfig(**value)
return value


class ToolRulesColumn(TypeDecorator):
"""Custom type for storing a list of ToolRules as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
"""Convert a list of ToolRules to JSON-serializable format."""
if value:
data = [rule.model_dump() for rule in value]
for d in data:
d["type"] = d["type"].value

for d in data:
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
return data
return value

def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
"""Convert JSON back to a list of ToolRules."""
if value:
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
return InitToolRule(**data)
elif rule_type == ToolRuleType.exit_loop:
return TerminalToolRule(**data)
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")


class ToolCallColumn(TypeDecorator):

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
values = []
for v in value:
if isinstance(v, ToolCall):
values.append(v.model_dump())
else:
values.append(v)
return values

return value

def process_result_value(self, value, dialect):
if value:
tools = []
for tool_value in value:
if "function" in tool_value:
tool_call_function = ToolCallFunction(**tool_value["function"])
del tool_value["function"]
else:
tool_call_function = None
tools.append(ToolCall(function=tool_call_function, **tool_value))
return tools
return value


class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""

impl = BINARY
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(BINARY())

def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, list):
value = np.array(value, dtype=np.float32)
return base64.b64encode(value.tobytes())

def process_result_value(self, value, dialect):
if not value:
return value
if dialect.name == "sqlite":
value = base64.b64decode(value)
return np.frombuffer(value, dtype=np.float32)
38 changes: 2 additions & 36 deletions letta/orm/message.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,12 @@
from typing import Optional

from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.custom_columns import ToolCallColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction


class ToolCallColumn(TypeDecorator):

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
values = []
for v in value:
if isinstance(v, ToolCall):
values.append(v.model_dump())
else:
values.append(v)
return values

return value

def process_result_value(self, value, dialect):
if value:
tools = []
for tool_value in value:
if "function" in tool_value:
tool_call_function = ToolCallFunction(**tool_value["function"])
del tool_value["function"]
else:
tool_call_function = None
tools.append(ToolCall(function=tool_call_function, **tool_value))
return tools
return value
from letta.schemas.openai.chat_completions import ToolCall


class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
Expand Down
28 changes: 1 addition & 27 deletions letta/orm/passage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import base64
from datetime import datetime
from typing import TYPE_CHECKING, Optional

import numpy as np
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import BINARY, TypeDecorator

from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.orm.custom_columns import CommonVector
from letta.orm.mixins import FileMixin, OrganizationMixin
from letta.orm.source import EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
Expand All @@ -21,30 +19,6 @@
from letta.orm.organization import Organization


class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""

impl = BINARY
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(BINARY())

def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, list):
value = np.array(value, dtype=np.float32)
return base64.b64encode(value.tobytes())

def process_result_value(self, value, dialect):
if not value:
return value
if dialect.name == "sqlite":
value = base64.b64decode(value)
return np.frombuffer(value, dtype=np.float32)


# TODO: After migration to Passage, will need to manually delete passages where files
# are deleted on web
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
Expand Down
Loading