Skip to content

Commit

Permalink
Add Module plugins API
Browse files Browse the repository at this point in the history
  • Loading branch information
drobison00 committed Oct 27, 2022
1 parent 6764cbf commit 863dfcd
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 50 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ add_library(libsrf
src/public/core/logging.cpp
src/public/cuda/device_guard.cpp
src/public/cuda/sync.cpp
src/public/experimental/modules/module_registry.cpp
src/public/experimental/modules/plugins.cpp
src/public/experimental/modules/sample_modules.cpp
src/public/experimental/modules/segment_module_registry.cpp
src/public/experimental/modules/segment_modules.cpp
src/public/manifold/manifold.cpp
src/public/memory/buffer_view.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

#pragma once

#include "srf/experimental/modules/segment_module_registry.hpp"
#include "srf/experimental/modules/module_registry.hpp"
#include "srf/experimental/modules/segment_modules.hpp"

#include <dlfcn.h>
#include <nlohmann/json.hpp>

namespace srf::modules {
Expand All @@ -33,7 +34,7 @@ struct ModelRegistryUtil
* @param registry_namespace Namespace where `name` should be registered.
*/
template <typename ModuleTypeT>
static void register_module(std::string name, std::string registry_namespace, const std::vector<unsigned int>& release_version)
static void create_registered_module(std::string name, std::string registry_namespace, const std::vector<unsigned int>& release_version)
{
static_assert(std::is_base_of_v<modules::SegmentModule, ModuleTypeT>);

Expand Down
96 changes: 96 additions & 0 deletions include/srf/experimental/modules/plugins.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <dlfcn.h>
#include <glog/logging.h>

#include <map>
#include <memory>
#include <mutex>
#include <sstream>

namespace srf::modules {

class ModulePluginLibrary
{
using module_plugin_map_t = std::map<std::string, std::mutex>;

public:
ModulePluginLibrary(ModulePluginLibrary&&) = delete;
ModulePluginLibrary(const ModulePluginLibrary&) = delete;

~ModulePluginLibrary() = default;

void operator=(const ModulePluginLibrary&) = delete;

static std::unique_ptr<ModulePluginLibrary> Acquire(std::unique_ptr<ModulePluginLibrary> uptr_plugin,
std::string plugin_library_path);

// Configuration so that dependent libraries will be searched for in
// 'path' during OpenLibraryHandle.
void set_library_directory(const std::string& path);

// Reset any configuration done by SetLibraryDirectory.
void reset_library_directory();

/**
* Load plugin module -- will load the plugin library and call its loader entrypoint to register
* any modules it contains.
*/
void load();

/**
* Unload the plugin module -- this will call the unload entrypoint of the plugin, which will then
* unload any registered models.
*/
void unload();


/**
* Return a list of modules published by the plugin
*/
unsigned int list_modules(const char** list);

private:
explicit ModulePluginLibrary() = delete;
explicit ModulePluginLibrary(std::string plugin_library_path) :
m_plugin_library_path(std::move(plugin_library_path))
{}

static std::mutex s_mutex;
static module_plugin_map_t s_plugin_map;

static const std::string PluginEntrypointLoad;
static const std::string PluginEntrypointUnload;
static const std::string PluginEntrypointList;

void* m_plugin_handle{nullptr};

bool m_loaded{false};
std::string m_plugin_library_path{};

bool (*m_plugin_load)();
bool (*m_plugin_unload)();
unsigned int (*m_plugin_list)(const char**);

void open_library_handle();
void get_plugin_entrypoint(const std::string& entrypoint_name, void** entrypoint);
};

} // namespace srf::modules
2 changes: 1 addition & 1 deletion include/srf/segment/builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "srf/core/watcher.hpp"
#include "srf/engine/segment/ibuilder.hpp"
#include "srf/exceptions/runtime_error.hpp"
#include "srf/experimental/modules/segment_module_registry.hpp"
#include "srf/experimental/modules/module_registry.hpp"
#include "srf/experimental/modules/segment_modules.hpp"
#include "srf/node/edge_builder.hpp"
#include "srf/node/rx_node.hpp"
Expand Down
16 changes: 8 additions & 8 deletions python/srf/core/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#include "pysrf/types.hpp"
#include "pysrf/utils.hpp"

#include "srf/experimental/modules/module_registry_util.hpp"
#include "srf/experimental/modules/sample_modules.hpp"
#include "srf/experimental/modules/segment_module_util.hpp"
#include "srf/experimental/modules/segment_modules.hpp"
#include "srf/node/edge_connector.hpp"
#include "srf/segment/builder.hpp"
Expand Down Expand Up @@ -176,19 +176,19 @@ PYBIND11_MODULE(segment, m)
Builder.def("make_py2cxx_edge_adapter", &BuilderProxy::make_py2cxx_edge_adapter);

/** Register test modules -- necessary for python unit tests**/
modules::ModelRegistryUtil::register_module<srf::modules::SimpleModule>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::SimpleModule>(
"SimpleModule", "srf_unittest", PybindSegmentModuleVersion);
modules::ModelRegistryUtil::register_module<srf::modules::ConfigurableModule>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::ConfigurableModule>(
"ConfigurableModule", "srf_unittest", PybindSegmentModuleVersion);
modules::ModelRegistryUtil::register_module<srf::modules::SourceModule>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::SourceModule>(
"SourceModule", "srf_unittest", PybindSegmentModuleVersion);
modules::ModelRegistryUtil::register_module<srf::modules::SinkModule>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::SinkModule>(
"SinkModule", "srf_unittest", PybindSegmentModuleVersion);
modules::ModelRegistryUtil::register_module<srf::modules::NestedModule>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::NestedModule>(
"NestedModule", "srf_unittest", PybindSegmentModuleVersion);
modules::ModelRegistryUtil::register_module<srf::modules::TemplateModule<int>>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::TemplateModule<int>>(
"TemplateModuleInt", "srf_unittest", PybindSegmentModuleVersion);
modules::ModelRegistryUtil::register_module<srf::modules::TemplateModule<std::string>>(
modules::ModelRegistryUtil::create_registered_module<srf::modules::TemplateModule<std::string>>(
"TemplateModuleString", "srf_unittest", PybindSegmentModuleVersion);

/** Segment Module Interface Declarations **/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
* limitations under the License.
*/

#include "srf/experimental/modules/segment_module_registry.hpp"

#include "srf/experimental/modules/module_registry.hpp"
#include "srf/experimental/modules/segment_modules.hpp"

#include <algorithm>
Expand Down Expand Up @@ -101,8 +100,6 @@ void ModuleRegistry::register_module(std::string name,
srf::modules::ModuleRegistry::module_constructor_t fn_constructor)
{
std::lock_guard<decltype(s_mutex)> lock(s_mutex);
// TODO(devin) : Reject modules that are not equal to the current build -- we will need to decide on
// better criteria going forward.
VLOG(2) << "Registering module: " << registry_namespace << "::" << name;
if (!is_version_compatible(release_version))
{
Expand Down Expand Up @@ -146,6 +143,8 @@ void ModuleRegistry::unregister_module(const std::string& name, const std::strin
{
std::lock_guard<decltype(s_mutex)> lock(s_mutex);

VLOG(2) << "Unregistering module " << registry_namespace << "::" << name;

if (contains(name, registry_namespace))
{
s_module_namespace_registry[registry_namespace].erase(name);
Expand All @@ -155,6 +154,12 @@ void ModuleRegistry::unregister_module(const std::string& name, const std::strin

name_map.erase(iter_erase);

if (s_module_namespace_registry[registry_namespace].empty()) {
VLOG(2) << "Namespace " << registry_namespace << " is empty, removing.";
s_module_namespace_registry.erase(registry_namespace);
s_module_name_map.erase(registry_namespace);
}

return;
}

Expand All @@ -172,6 +177,7 @@ void ModuleRegistry::unregister_module(const std::string& name, const std::strin

bool ModuleRegistry::is_version_compatible(const std::vector<unsigned int>& release_version)
{
// TODO(devin) improve criteria for module compatibility
return std::equal(ModuleRegistry::Version.begin(),
ModuleRegistry::Version.begin() + ModuleRegistry::VersionElements,
release_version.begin());
Expand Down
134 changes: 134 additions & 0 deletions src/public/experimental/modules/plugins.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "srf/experimental/modules/plugins.hpp"

#include <iostream>
#include <memory>
#include <mutex>
#include <string>

namespace srf::modules {

std::mutex ModulePluginLibrary::s_mutex{};

const std::string ModulePluginLibrary::PluginEntrypointList{"SRF_MODULE_entrypoint_list"};
const std::string ModulePluginLibrary::PluginEntrypointLoad{"SRF_MODULE_entrypoint_load"};
const std::string ModulePluginLibrary::PluginEntrypointUnload{"SRF_MODULE_entrypoint_unload"};

std::unique_ptr<ModulePluginLibrary> ModulePluginLibrary::Acquire(std::unique_ptr<ModulePluginLibrary> uptr_plugin,
std::string plugin_library_path)
{
std::lock_guard<decltype(s_mutex)> lock(s_mutex);

uptr_plugin.reset(new ModulePluginLibrary(std::move(plugin_library_path)));

return std::move(uptr_plugin);
}

void ModulePluginLibrary::set_library_directory(const std::string& path)
{
throw std::runtime_error("Unimplemented");
}

void ModulePluginLibrary::reset_library_directory()
{
throw std::runtime_error("Unimplemented");
}

unsigned int ModulePluginLibrary::list_modules(const char** list)
{
return m_plugin_list(list);
}

void ModulePluginLibrary::load()
{
if (m_loaded)
{
return;
}

open_library_handle();
get_plugin_entrypoint(PluginEntrypointList, reinterpret_cast<void**>(&m_plugin_list));
get_plugin_entrypoint(PluginEntrypointLoad, reinterpret_cast<void**>(&m_plugin_load));
get_plugin_entrypoint(PluginEntrypointUnload, reinterpret_cast<void**>(&m_plugin_unload));

m_plugin_load();
m_loaded = true;
}

void ModulePluginLibrary::unload()
{
if (!m_loaded)
{
return;
}

m_plugin_unload();

if (dlclose(m_plugin_handle) != 0)
{
std::stringstream sstream;

sstream << "Failed to close plugin module -> " << dlerror();
VLOG(2) << sstream.str();
throw std::runtime_error(dlerror());
}

m_plugin_load = nullptr;
m_plugin_unload = nullptr;

m_loaded = false;
}

void ModulePluginLibrary::open_library_handle()
{
m_plugin_handle = dlopen(m_plugin_library_path.c_str(), RTLD_NOW | RTLD_LOCAL);
if (m_plugin_handle == nullptr)
{
std::stringstream sstream;

sstream << "Failed to open plugin module -> " << dlerror();
VLOG(2) << sstream.str();

throw std::runtime_error(sstream.str());
}
}

void ModulePluginLibrary::get_plugin_entrypoint(const std::string& entrypoint_name, void** entrypoint)
{
*entrypoint = nullptr;

dlerror();
void* _fn = dlsym(m_plugin_handle, entrypoint_name.c_str());

const char* dlsym_error = dlerror();
if (dlsym_error != nullptr)
{
std::stringstream sstream;

sstream << "Failed to find entrypoint -> '" << entrypoint_name << "' in '" << m_plugin_library_path << " : "
<< dlsym_error;

VLOG(2) << sstream.str();
throw std::invalid_argument(sstream.str());
}

*entrypoint = _fn;
}

} // namespace srf::modules
Loading

0 comments on commit 863dfcd

Please sign in to comment.