Description
RuntimeError Traceback (most recent call last)
Cell In[9], line 4
1 from finetune_visualglm import FineTuneVisualGLMModel
2 import argparse
----> 4 model, args = FineTuneVisualGLMModel.from_pretrained('/kaggle/working/checkpoints/finetune-visualglm-6b-04-09-09-10',
5 args=argparse.Namespace(
6 fp16=True,
7 skip_init=True,
8 use_gpu_initialization=True,
9 device='cuda',
10 ))
11 model.get_mixin('lora').merge_lora()
12 args.layer_range = []
File /opt/conda/lib/python3.10/site-packages/sat/model/base_model.py:207, in BaseModel.from_pretrained(cls, name, args, home_path, url, prefix, build_only, overwrite_args, **kwargs)
205 model = get_model(args, cls, **kwargs)
206 if not build_only:
--> 207 load_checkpoint(model, args, load_path=model_path, prefix=prefix)
208 return model, args
File /opt/conda/lib/python3.10/site-packages/sat/training/model_io.py:238, in load_checkpoint(model, args, load_path, prefix)
235 module = model
237 # only load module, other hyperparameters are just for recording.
--> 238 missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
239 if len(unexpected_keys) > 0:
240 print_rank0(
241 f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2138, in Module.load_state_dict(self, state_dict, strict, assign)
2131 out = hook(module, incompatible_keys)
2132 assert out is None, (
2133 "Hooks registered with ``register_load_state_dict_post_hook`` are not"
2134 "expected to return new values, if incompatible_keys need to be modified,"
2135 "it should be done inplace."
2136 )
-> 2138 load(self, state_dict)
2139 del load
2141 if strict:
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2126 (3 times)]
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2120, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2118 if assign:
2119 local_metadata['assign_to_params_buffers'] = assign
-> 2120 module._load_from_state_dict(
2121 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2122 for name, child in module._modules.items():
2123 if child is not None:
File /opt/conda/lib/python3.10/site-packages/sat/model/finetune/lora2.py:47, in HackLinearNF4._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
45 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
46 if prefix + 'weight' in state_dict:
---> 47 self.weight.data.copy_(state_dict[prefix+'weight'])
48 if self.weight.data.dtype == torch.uint8:
49 copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
RuntimeError: output with shape [25165824, 1] doesn't match the broadcast shape [25165824, 0]
How can I solve this problem?
Activity