Skip to content

qlora merge lora weights error  #350

Open
@zousss

Description

RuntimeError Traceback (most recent call last)
Cell In[9], line 4
1 from finetune_visualglm import FineTuneVisualGLMModel
2 import argparse
----> 4 model, args = FineTuneVisualGLMModel.from_pretrained('/kaggle/working/checkpoints/finetune-visualglm-6b-04-09-09-10',
5 args=argparse.Namespace(
6 fp16=True,
7 skip_init=True,
8 use_gpu_initialization=True,
9 device='cuda',
10 ))
11 model.get_mixin('lora').merge_lora()
12 args.layer_range = []

File /opt/conda/lib/python3.10/site-packages/sat/model/base_model.py:207, in BaseModel.from_pretrained(cls, name, args, home_path, url, prefix, build_only, overwrite_args, **kwargs)
205 model = get_model(args, cls, **kwargs)
206 if not build_only:
--> 207 load_checkpoint(model, args, load_path=model_path, prefix=prefix)
208 return model, args

File /opt/conda/lib/python3.10/site-packages/sat/training/model_io.py:238, in load_checkpoint(model, args, load_path, prefix)
235 module = model
237 # only load module, other hyperparameters are just for recording.
--> 238 missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
239 if len(unexpected_keys) > 0:
240 print_rank0(
241 f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2138, in Module.load_state_dict(self, state_dict, strict, assign)
2131 out = hook(module, incompatible_keys)
2132 assert out is None, (
2133 "Hooks registered with register_load_state_dict_post_hook are not"
2134 "expected to return new values, if incompatible_keys need to be modified,"
2135 "it should be done inplace."
2136 )
-> 2138 load(self, state_dict)
2139 del load
2141 if strict:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2126 (3 times)]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2120, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2118 if assign:
2119 local_metadata['assign_to_params_buffers'] = assign
-> 2120 module._load_from_state_dict(
2121 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2122 for name, child in module._modules.items():
2123 if child is not None:

File /opt/conda/lib/python3.10/site-packages/sat/model/finetune/lora2.py:47, in HackLinearNF4._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
45 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
46 if prefix + 'weight' in state_dict:
---> 47 self.weight.data.copy_(state_dict[prefix+'weight'])
48 if self.weight.data.dtype == torch.uint8:
49 copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)

RuntimeError: output with shape [25165824, 1] doesn't match the broadcast shape [25165824, 0]

How can I solve this problem?

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      qlora merge lora weights error · Issue #350 · THUDM/VisualGLM-6B