Introduce in-memory resource abstraction #375
This follows from discussion in guillaume-be#366.

The goal of this change is to allow for weights to be loaded from a copy of `rust_model.ot` that is already present in memory. There are two ways in which that data might be present:

1. As a `HashMap<String, Tensor>` from previous interaction with `tch`
2. As a contiguous buffer of the file data

One or the other mechanism might be preferable depending on how user code is using the model data. In some sense, implementing a provider based on the second option is more of a convenience method for the user to avoid the `tch::nn::VarStore::load_from_stream` interaction.

I've changed the definition of the `ResourceProvider` trait to require that it be both `Send` and `Sync`. There are currently certain contexts where `dyn ResourceProvider + Send` is required, but in theory before this change an implementation might not be `Send` (or `Sync`). The existing providers are both `Send` and `Sync`, and it seems reasonable (if technically incorrect) for user code to assume this to be true. I don't see a downside to making this explicit, but that part of this change might be better suited for separate discussion. I am not trying to sneak it in.

The `enum Resource` data type is used here as a means to abstract over the possible ways a `ResourceProvider` might represent an underlying resource. Without this, it would be necessary to either call different trait methods until one succeeded or implement `as_any` and downcast in order to implement `load_weights` similarly to how it is now. Those options seemed less preferable to creating a wrapper.

While it would be possible to replace all calls to `get_local_path` with the `get_resource` API, removal of the existing function would be a very big breaking change. As such, this change also introduces `RustBertError::UnsupportedError` to allow for the different methods to coexist. An alternative would be for the new `ResourceProvider`s to write their resources to a temporary disk location and return an appropriate path, but that is counter to the purpose of the new `ResourceProvider`s and so I chose not to implement that.
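The design described above can be illustrated with a minimal, std-only sketch. It is not the crate's actual code: `SketchError` is a hypothetical stand-in for `RustBertError`, the struct fields and the `describe` helper are invented for illustration, and the buffer provider here simply owns a `Vec<u8>` and hands out a borrowed slice. It shows how a wrapper enum lets callers match on the representation, and how an "unsupported" error lets `get_local_path` and `get_resource` coexist.

```rust
use std::path::PathBuf;

/// Hypothetical, simplified stand-in for `RustBertError`.
#[derive(Debug, PartialEq)]
pub enum SketchError {
    /// The provider cannot satisfy this kind of request.
    Unsupported,
}

/// Simplified stand-in for the `enum Resource` wrapper:
/// a resource is either a path on disk or bytes already in memory.
#[derive(Debug, PartialEq)]
pub enum Resource<'a> {
    PathBuf(PathBuf),
    Buffer(&'a [u8]),
}

/// Requiring `Send + Sync` as supertraits makes the thread-safety
/// assumption explicit for every implementation.
pub trait ResourceProvider: Send + Sync {
    fn get_local_path(&self) -> Result<PathBuf, SketchError>;
    fn get_resource(&self) -> Result<Resource<'_>, SketchError>;
}

/// Illustrative file-backed provider (field name is an assumption).
pub struct LocalResource {
    pub local_path: PathBuf,
}

impl ResourceProvider for LocalResource {
    fn get_local_path(&self) -> Result<PathBuf, SketchError> {
        Ok(self.local_path.clone())
    }
    fn get_resource(&self) -> Result<Resource<'_>, SketchError> {
        Ok(Resource::PathBuf(self.local_path.clone()))
    }
}

/// Illustrative in-memory provider (field name is an assumption).
pub struct BufferResource {
    pub data: Vec<u8>,
}

impl ResourceProvider for BufferResource {
    // There is no file on disk, so the path-based API reports
    // "unsupported" instead of writing a temporary file.
    fn get_local_path(&self) -> Result<PathBuf, SketchError> {
        Err(SketchError::Unsupported)
    }
    fn get_resource(&self) -> Result<Resource<'_>, SketchError> {
        Ok(Resource::Buffer(&self.data))
    }
}

/// Callers handle both representations with one `match`,
/// with no downcasting and no trial-and-error over trait methods.
pub fn describe(rp: &dyn ResourceProvider) -> String {
    match rp.get_resource() {
        Ok(Resource::PathBuf(p)) => format!("path:{}", p.display()),
        Ok(Resource::Buffer(b)) => format!("buffer:{} bytes", b.len()),
        Err(e) => format!("error:{:?}", e),
    }
}
```

With this shape, existing callers of `get_local_path` keep working for file-backed providers, while in-memory providers fail with an explicit error rather than touching the disk.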
After thinking more about the potential for Dropping the requirement I've added here for
- Remove `Resource::NamedTensors`
- Change `BufferResource` to contain a `&[u8]` rather than `Vec<u8>`
Minor changes to in-memory resources implementation
Thank you @mweber15. Could you please add a small test that validates and illustrates that the approach works for the buffer resource? You could pick any of the small-sized models (e.g. DistilBART as done so far, or a DistilBERT model).
Hi @guillaume-be, the recent build change has broken the build for me in a couple of ways that I'm working through. Once I have that sorted I will add an illustrative test and make sure the doc comments reflect the current state of the PR. Hopefully soon.
@@ -20,14 +20,13 @@ fn main() -> anyhow::Result<()> {
     let merges_resource = Box::new(RemoteResource::from_pretrained(
         DebertaMergesResources::DEBERTA_BASE_MNLI,
     ));
-    let model_resource = Box::new(RemoteResource::from_pretrained(
+    let mut model_resource = Box::new(RemoteResource::from_pretrained(
Note that with the latest changes (`RwLock`s) there is no need for the resource to be mutable anymore, meaning the changes should be essentially backward compatible (it should be possible to leave examples/tests unchanged).
The signature of `load_weights` becomes:
pub fn load_weights(
    rp: &(impl ResourceProvider + ?Sized),
    vs: &mut VarStore,
) -> Result<(), RustBertError> {
    match rp.get_resource()? {
        Resource::Buffer(mut data) => {
            vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?;
            Ok(())
        }
        Resource::PathBuf(path) => Ok(vs.load(path)?),
    }
}
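To see why the provider no longer needs to be mutable, here is a std-only sketch of the same dispatch. Everything in it is a simplified assumption rather than the crate's real code: `MockVarStore` is a hypothetical stand-in for `tch::nn::VarStore` that only counts bytes, `std::io::Error` replaces `RustBertError`, and the `Arc<RwLock<Vec<u8>>>` field is one plausible way to obtain a write guard from a shared reference.

```rust
use std::io::{Cursor, Read, Seek};
use std::ops::DerefMut;
use std::path::PathBuf;
use std::sync::{Arc, RwLock, RwLockWriteGuard};

/// Hypothetical stand-in for `tch::nn::VarStore`; it only records
/// how many bytes it was given instead of parsing tensors.
#[derive(Default)]
pub struct MockVarStore {
    pub bytes_loaded: usize,
}

impl MockVarStore {
    pub fn load_from_stream<S: Read + Seek>(&mut self, mut stream: S) -> std::io::Result<()> {
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf)?;
        self.bytes_loaded = buf.len();
        Ok(())
    }

    pub fn load(&mut self, path: PathBuf) -> std::io::Result<()> {
        self.bytes_loaded = std::fs::read(path)?.len();
        Ok(())
    }
}

/// Simplified resource wrapper: the buffer variant carries a write guard,
/// so mutable access comes from the lock rather than from `&mut self`.
pub enum Resource<'a> {
    PathBuf(PathBuf),
    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}

pub trait ResourceProvider: Send + Sync {
    fn get_resource(&self) -> std::io::Result<Resource<'_>>;
}

/// Illustrative in-memory provider (field layout is an assumption).
pub struct BufferResource {
    pub data: Arc<RwLock<Vec<u8>>>,
}

impl ResourceProvider for BufferResource {
    fn get_resource(&self) -> std::io::Result<Resource<'_>> {
        // A poisoned lock is mapped to a plain I/O error in this sketch.
        let guard = self
            .data
            .write()
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "lock poisoned"))?;
        Ok(Resource::Buffer(guard))
    }
}

/// Same dispatch shape as the review comment: `?Sized` lets callers pass
/// either a concrete provider or a `dyn ResourceProvider`.
pub fn load_weights(
    rp: &(impl ResourceProvider + ?Sized),
    vs: &mut MockVarStore,
) -> std::io::Result<()> {
    match rp.get_resource()? {
        Resource::Buffer(mut data) => {
            // `Cursor` over `&mut Vec<u8>` provides the `Read + Seek` stream.
            vs.load_from_stream(Cursor::new(data.deref_mut()))?;
            Ok(())
        }
        Resource::PathBuf(path) => vs.load(path),
    }
}
```

Because the guard is released when each call returns, the same provider can be loaded repeatedly through a shared reference, which is what keeps existing examples and tests free of `mut` bindings.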
Thank you for the PR!
Thank you so much for your help putting this together!