Skip to content
This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Commit

Permalink
Test import on CI (#53)
Browse files Browse the repository at this point in the history
* test import on ci

* conditionally add HaikuModule to __all__

* expose Treex

* expose Filters

* remove elegy import test
  • Loading branch information
cgarciae authored Dec 18, 2021
1 parent 5a7ab57 commit 80fa406
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 38 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/ci_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,34 @@ jobs:

- name: Test Examples
run: bash scripts/test-examples.sh

test-import:
name: Test Import without Dev Dependencies
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
runs-on: ubuntu-latest
strategy:
matrix:
# python-version: [3.9]
python-version: [3.7, 3.8, 3.9]
steps:
- name: Check out the code
uses: actions/checkout@v2
with:
fetch-depth: 1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
uses: snok/install-poetry@v1.1.1
with:
version: 1.1.4

- name: Install Dependencies
run: |
poetry config virtualenvs.create false
poetry install --no-dev
- name: Test Import Treex
run: python -c "import treex"
74 changes: 38 additions & 36 deletions treex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from treex.nn import *
from treex.optimizer import Optimizer
from treex.treex import Filters, Treex
from treex.types import (
BatchStat,
Cache,
Expand All @@ -38,39 +39,40 @@

from . import losses, metrics, nn, regularizers

__all__ = (
treeo.__all__
+ nn.__all__
+ [
"KeySeq",
"Loss",
"LossAndLogs",
"Metric",
"Module",
"ModuleMeta",
"compact_module",
"preserve_state",
"next_key",
"rng_key",
"Optimizer",
"BatchStat",
"Cache",
"Log",
"LossLog",
"MetricLog",
"MetricState",
"ModelState",
"OptState",
"Parameter",
"Rng",
"State",
"TreePart",
"Initializer",
"Inputs",
"Named",
"losses",
"metrics",
"nn",
"regularizers",
]
)
__all__ = [
"KeySeq",
"Loss",
"LossAndLogs",
"Metric",
"Module",
"ModuleMeta",
"compact_module",
"preserve_state",
"next_key",
"rng_key",
"Optimizer",
"Treex",
"Filters",
"BatchStat",
"Cache",
"Log",
"LossLog",
"MetricLog",
"MetricState",
"ModelState",
"OptState",
"Parameter",
"Rng",
"State",
"TreePart",
"Initializer",
"Inputs",
"Named",
"losses",
"metrics",
"nn",
"regularizers",
]

__all__.extend(treeo.__all__)
__all__.extend(nn.__all__)
8 changes: 6 additions & 2 deletions treex/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

try:
from .haiku_module import HaikuModule

_haiku_available = True
except types.OptionalDependencyNotFound:
pass
_haiku_available = False

__all__ = [
"BatchNorm",
Expand All @@ -25,5 +27,7 @@
"Lambda",
"Sequential",
"sequence",
"HaikuModule",
]

if _haiku_available:
__all__.append("HaikuModule")

0 comments on commit 80fa406

Please sign in to comment.