Skip to content

Commit

Permalink
Merge pull request #442 from ChrisCummins/bitcode
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins authored Sep 30, 2021
2 parents 9734103 + b2dc549 commit b24a21e
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 16 deletions.
7 changes: 7 additions & 0 deletions compiler_gym/envs/llvm/service/Observation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ Status setObservation(LlvmObservationSpace space, const fs::path& workingDirecto
reply.set_string_value(ss.str());
break;
}
case LlvmObservationSpace::BITCODE: {
std::string bitcode;
llvm::raw_string_ostream outbuffer(bitcode);
llvm::WriteBitcodeToFile(benchmark.module(), outbuffer);
reply.set_binary_value(outbuffer.str());
break;
}
case LlvmObservationSpace::BITCODE_FILE: {
// Generate an output path with 16 bits of randomness.
const auto outpath = fs::unique_path(workingDirectory / "module-%%%%%%%%.bc");
Expand Down
21 changes: 9 additions & 12 deletions compiler_gym/envs/llvm/service/ObservationSpaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,25 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
space.set_name(util::enumNameToPascalCase<LlvmObservationSpace>(value));
switch (value) {
case LlvmObservationSpace::IR: {
ScalarRange irSize;
space.mutable_string_size_range()->mutable_min()->set_value(0);
space.set_deterministic(true);
space.set_platform_dependent(false);
break;
}
case LlvmObservationSpace::IR_SHA1: {
ScalarRange sha1Size;
space.mutable_string_size_range()->mutable_min()->set_value(40);
space.mutable_string_size_range()->mutable_max()->set_value(40);
space.set_deterministic(true);
space.set_platform_dependent(false);
break;
}
case LlvmObservationSpace::BITCODE: {
space.mutable_binary_size_range()->mutable_min()->set_value(0);
space.set_deterministic(true);
space.set_platform_dependent(false);
break;
}
case LlvmObservationSpace::BITCODE_FILE: {
ScalarRange pathLength;
space.mutable_string_size_range()->mutable_min()->set_value(0);
// 4096 is the maximum path length for most filesystems.
space.mutable_string_size_range()->mutable_max()->set_value(kMaximumPathLength);
Expand Down Expand Up @@ -89,10 +92,8 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
}
case LlvmObservationSpace::PROGRAML: {
// ProGraML serializes the graph to JSON.
ScalarRange encodedSize;
encodedSize.mutable_min()->set_value(0);
space.set_opaque_data_format("json://networkx/MultiDiGraph");
*space.mutable_string_size_range() = encodedSize;
space.mutable_string_size_range()->mutable_min()->set_value(0);
space.set_deterministic(true);
space.set_platform_dependent(false);
programl::ProgramGraph graph;
Expand All @@ -104,10 +105,8 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
}
case LlvmObservationSpace::PROGRAML_JSON: {
// ProGraML serializes the graph to JSON.
ScalarRange encodedSize;
encodedSize.mutable_min()->set_value(0);
space.set_opaque_data_format("json://");
*space.mutable_string_size_range() = encodedSize;
space.mutable_string_size_range()->mutable_min()->set_value(0);
space.set_deterministic(true);
space.set_platform_dependent(false);
programl::ProgramGraph graph;
Expand All @@ -119,10 +118,8 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
}
case LlvmObservationSpace::CPU_INFO: {
// Hardware info is returned as a JSON
ScalarRange encodedSize;
encodedSize.mutable_min()->set_value(0);
space.set_opaque_data_format("json://");
*space.mutable_string_size_range() = encodedSize;
space.mutable_string_size_range()->mutable_min()->set_value(0);
space.set_deterministic(true);
space.set_platform_dependent(true);
*space.mutable_default_value()->mutable_string_value() = "{}";
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/envs/llvm/service/ObservationSpaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ enum class LlvmObservationSpace {
IR,
/** The 40-digit hex SHA1 checksum of the LLVM module. */
IR_SHA1,
/** Get the bitcode as a bytes array. */
BITCODE,
/** Write the bitcode to a file and return its path as a string. */
BITCODE_FILE,
/** The counts of all instructions in a program. */
Expand Down
13 changes: 11 additions & 2 deletions compiler_gym/spaces/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,18 @@ def contains(self, x):
upper_bound = float("inf") if self.size_range[1] is None else self.size_range[1]
if not (lower_bound <= len(x) <= upper_bound):
return False
for element in x:
if not isinstance(element, self.dtype):

# TODO(cummins): The dtype API is inconsistent. When dtype=str or
# dtype=bytes, we expect this to be the type of the entire sequence. But
# for dtype=int, we expect this to be the type of each element. We
# should distinguish these differences better.
if self.dtype in {str, bytes}:
if not isinstance(x, self.dtype):
return False
else:
for element in x:
if not isinstance(element, self.dtype):
return False

# Run the bounds check on every scalar element, if there is a scalar
# range specified.
Expand Down
40 changes: 38 additions & 2 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_observation_spaces(env: LlvmEnv):
assert set(env.observation.spaces.keys()) == {
"Ir",
"IrSha1",
"Bitcode",
"BitcodeFile",
"InstCount",
"InstCountDict",
Expand Down Expand Up @@ -105,12 +106,31 @@ def test_ir_sha1_observation_space(env: LlvmEnv):


def test_bitcode_observation_space(env: LlvmEnv):
env.reset("cbench-v1/crc32")
key = "Bitcode"
space = env.observation.spaces[key]
assert isinstance(space.space, Sequence)
assert space.space.dtype == bytes
assert space.space.size_range == (0, None)

assert space.deterministic
assert not space.platform_dependent

value: str = env.observation[key]
print(value) # For debugging in case of error.
assert isinstance(value, bytes)
assert space.space.contains(value)


def test_bitcode_file_observation_space(env: LlvmEnv):
env.reset("cbench-v1/crc32")
key = "BitcodeFile"
space = env.observation.spaces[key]
assert isinstance(space.space, Sequence)
assert space.space.dtype == str
assert space.space.size_range == (0, 4096)
assert not space.deterministic
assert not space.platform_dependent

value: str = env.observation[key]
print(value) # For debugging in case of error.
Expand All @@ -121,8 +141,24 @@ def test_bitcode_observation_space(env: LlvmEnv):
finally:
os.unlink(value)

assert not space.deterministic
assert not space.platform_dependent

@pytest.mark.parametrize(
"benchmark_uri", ["cbench-v1/crc32", "cbench-v1/qsort", "cbench-v1/gsm"]
)
def test_bitcode_file_equivalence(env: LlvmEnv, benchmark_uri: str):
"""Test that LLVM produces the same bitcode as a file and as a byte array."""
env.reset(benchmark=benchmark_uri)

bitcode = env.observation.Bitcode()
bitcode_file = env.observation.BitcodeFile()

try:
with open(bitcode_file, "rb") as f:
bitcode_from_file = f.read()

assert bitcode == bitcode_from_file
finally:
os.unlink(bitcode_file)


# The Autophase feature vector for benchmark://cbench-v1/crc32 in its initial
Expand Down
7 changes: 7 additions & 0 deletions tests/spaces/sequence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,12 @@ def test_contains_with_float_scalar_range():
assert not space.contains([0.0, 0]) # wrong shape


def test_bytes_contains():
space = Sequence(size_range=(0, None), dtype=bytes)
assert space.contains(b"Hello, world!")
assert space.contains(b"")
assert not space.contains("Hello, world!")


if __name__ == "__main__":
main()

0 comments on commit b24a21e

Please sign in to comment.