Skip to content

Commit

Permalink
Replaced tensorflow for numpy in type tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Martí Zamora Casals committed Jan 7, 2025
1 parent c7eab28 commit e570f1b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ repos:
additional_dependencies:
[
beartype,
tensorflow,
numpy<2,
]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.14.1
Expand All @@ -48,6 +48,6 @@ repos:
additional_dependencies:
[
beartype,
tensorflow,
numpy<2,
]
args: ["--ignore-missing-imports", "--follow-imports=skip"]
10 changes: 5 additions & 5 deletions test/types/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

import tensorflow as tf
import numpy as np
from beartype import beartype

from jaxtyping import Float, Int, jaxtyped
Expand All @@ -11,8 +11,8 @@
class User:
name: str
age: int
items: Float[tf.Tensor, " N"]
timestamps: Int[tf.Tensor, " N"]
items: Float[np.ndarray, " N"]
timestamps: Int[np.ndarray, " N"]


@jaxtyped(typechecker=beartype)
Expand All @@ -24,8 +24,8 @@ def transform_user(user: User, increment_age: int = 1) -> User:
user = User(
name="John",
age=20,
items=tf.random.normal([10]),
timestamps=tf.random.uniform([10], minval=0, maxval=100, dtype=tf.int32),
items=np.random.normal(size=10),
timestamps=np.random.randint(0, 100, size=10),
)

new_user = transform_user(user, increment_age=2)

0 comments on commit e570f1b

Please sign in to comment.