Skip to content

Commit

Permalink
Update test_nemoguardrails.py
Browse files Browse the repository at this point in the history
  • Loading branch information
henchaves committed Dec 16, 2024
1 parent 3f3d121 commit d090239
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions tests/integrations/test_nemoguardrails.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
import pytest
from pathlib import Path
import tempfile
from unittest.mock import Mock, patch

import pandas as pd
Expand All @@ -9,14 +12,28 @@
from giskard.scanner.report import ScanReport


def _generate_rails(report: ScanReport, filename=None, colang_version="1.0"):
if filename:
with tempfile.TemporaryDirectory() as tmpdir:
dest = Path(tmpdir).joinpath("rails.co")
report.generate_rails(filename=dest, colang_version=colang_version)
assert dest.exists()
assert dest.is_file()
rails = dest.read_text(encoding="utf-8")
else:
rails = report.generate_rails(colang_version=colang_version)
return rails


@pytest.mark.parametrize("filename", [(None), ("rails.co")])
@patch("giskard.integrations.nemoguardrails.get_default_client")
def test_generate_colang_v1_rails_from_scan(get_default_client_mock):
def test_generate_colang_v1_rails_from_scan(get_default_client_mock, filename):
report = make_test_report()

llm_client = get_default_client_mock()
llm_client.complete.side_effect = make_llm_answers()

rails = report.generate_rails()
rails = _generate_rails(report, filename=filename, colang_version="1.0")

# Check that the file is correctly formatted
parsed = parse_colang_file("rails.co", rails, version="1.0")
Expand All @@ -27,14 +44,15 @@ def test_generate_colang_v1_rails_from_scan(get_default_client_mock):
assert parsed["flows"][1]["id"] == "ask help on illegal activities"


@pytest.mark.parametrize("filename", [(None), ("rails.co")])
@patch("giskard.integrations.nemoguardrails.get_default_client")
def test_generate_colang_v2_rails_from_scan(get_default_client_mock):
def test_generate_colang_v2_rails_from_scan(get_default_client_mock, filename):
report = make_test_report()

llm_client = get_default_client_mock()
llm_client.complete.side_effect = make_llm_answers()

rails = report.generate_rails(colang_version="2.x")
rails = _generate_rails(report, filename=filename, colang_version="2.x")

# Check that the file is correctly formatted
parsed = parse_colang_file("rails.co", rails, version="2.x")
Expand Down

0 comments on commit d090239

Please sign in to comment.