
Delete IA3 adapter #1153

Merged (1 commit, Nov 20, 2023)
Conversation

@alexrs (Contributor) commented Nov 20, 2023

What

As discussed in #980 (specifically #980 (comment)), we are breaking that PR into smaller PRs that can be merged faster.

In this first PR, we include the ability to delete $(IA)^3$ adapters.

How

A new method, `delete_adapter`, has been added to the `IA3Model` class, with the following signature:

def delete_adapter(self, adapter_name: str):
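To make the intended behavior concrete, here is a minimal, self-contained sketch of adapter deletion. This is not PEFT's actual implementation: the `ToyAdapterModel` class, its `adapter_params` dict, and the fallback-to-remaining-adapter logic are all illustrative assumptions; it only shows the general shape of removing an adapter's config and parameters by name.

```python
# Illustrative sketch only (NOT PEFT's actual code): adapters are assumed
# to live in plain dicts keyed by adapter name.
class ToyAdapterModel:
    def __init__(self):
        self.peft_config = {}      # adapter name -> config object
        self.adapter_params = {}   # adapter name -> learned parameters
        self.active_adapter = None

    def add_adapter(self, adapter_name, config):
        self.peft_config[adapter_name] = config
        self.adapter_params[adapter_name] = [0.0]  # placeholder weights
        if self.active_adapter is None:
            self.active_adapter = adapter_name

    def delete_adapter(self, adapter_name: str):
        # Refuse to delete an adapter that was never registered.
        if adapter_name not in self.peft_config:
            raise ValueError(f"Adapter {adapter_name} does not exist")
        del self.peft_config[adapter_name]
        self.adapter_params.pop(adapter_name, None)
        # If the deleted adapter was active, fall back to any remaining one.
        if self.active_adapter == adapter_name:
            remaining = list(self.peft_config)
            self.active_adapter = remaining[0] if remaining else None


model = ToyAdapterModel()
model.add_adapter("default", {"type": "ia3"})
model.add_adapter("extra", {"type": "ia3"})
model.delete_adapter("extra")
print(sorted(model.peft_config))  # ['default']
print(model.active_adapter)       # default
```

The key points the real method shares with this sketch are the existence check on `peft_config` (discussed in the review below) and removing the adapter's entries from every per-adapter container.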

Test Plan

  • Added PeftType.IA3 to testing_common.py#_test_delete_adapter
  • Run the tests and checked that they pass

Comment on lines -30 to -35

    # All names of other parameters that may contain adapter-related parameters
    other_layer_names = ("scaling",)

    def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
        self.base_layer = base_layer
        self.scaling = {}
@alexrs (Contributor, Author):
As far as I understand, we do not use self.scaling and $(IA)^3$ does not have this parameter.

Member:

Right, thanks for cleaning that up.

@BenjaminBossan (Member) left a comment:

Thank you so much Alejandro for factoring out the deletion-feature and creating this super clean PR. Not only do you show great understanding of the PEFT code base, but also improved some unrelated lines on top, perfect!

I have only a few minor comments, they should be easy to fix and then we're good to merge.

src/peft/tuners/ia3/layer.py (resolved)
src/peft/tuners/ia3/model.py (resolved)
    Args:
        adapter_name (str): Name of the adapter to be deleted.
    """
    if adapter_name not in list(self.peft_config.keys()):

Member:

Suggested change
-    if adapter_name not in list(self.peft_config.keys()):
+    if adapter_name not in self.peft_config:

This should be more efficient.
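The reviewer's point can be illustrated with a plain dict (the `peft_config` contents here are made up): `in` on a dict uses the hash table directly, whereas `list(d.keys())` first materializes every key into a list and then searches it linearly.

```python
peft_config = {"default": "cfg0", "extra": "cfg1"}

# Both expressions test key membership and give the same answer, but the
# first builds an intermediate list of all keys before scanning it,
# while the second is a direct O(1) hash lookup on the dict.
via_list = "missing" not in list(peft_config.keys())
direct = "missing" not in peft_config
print(via_list, direct)  # True True
```

The difference is negligible for a handful of adapters, but the direct form is both shorter and asymptotically cheaper, so it is the idiomatic choice.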

src/peft/tuners/ia3/model.py (resolved)
@@ -905,7 +905,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
         self.assertFalse(adapter_to_delete in model.peft_config)
         self.assertEqual(model.active_adapters, ["default"])

-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
Member:

Good amendment, but this made me curious why the code worked at all, because "lora" was hard-coded but we also test LoHa and LoKr. It turns out that the prefix check is totally unnecessary and the code can be simplified to:

key_list = [key for key, _ in model.named_modules()]

Could you please fix that? Same for _test_delete_inactive_adapter and _test_weighted_combination_of_adapters. Thanks.

@alexrs (Contributor, Author):

I was also curious about why this worked at all. I'll fix it!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@alexrs (Contributor, Author) commented Nov 20, 2023

> Not only do you show great understanding of the PEFT code base, but also improved some unrelated lines on top, perfect!

Thanks! I've spent quite a few hours navigating the codebase for my MSc thesis, and it has helped me a lot. Thanks for the good work! 👏

@BenjaminBossan (Member) left a comment:

Fantastic work, thanks a lot!

@BenjaminBossan BenjaminBossan merged commit 8351331 into huggingface:main Nov 20, 2023
@alexrs alexrs deleted the remove-ia3-adapter branch November 20, 2023 18:54