Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/wyf3/llm_related
Browse files Browse the repository at this point in the history
  • Loading branch information
wyf3 committed Dec 13, 2024
2 parents c453cb4 + f4c2b38 commit 06c1d79
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions train_multimodal_from_scratch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def forward(self, input_ids, labels, pixel_values, attention_mask=None):
def merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):

num_images, num_image_patches, embed_dim = image_features.shape
batch_indices, iamge_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])
batch_indices, image_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])

inputs_embeds[batch_indices, iamge_indices] = image_features.view(-1, embed_dim)
inputs_embeds[batch_indices, image_indices] = image_features.view(-1, embed_dim)

return inputs_embeds

Expand Down
6 changes: 3 additions & 3 deletions train_multimodal_from_scratch/trainer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -134,8 +134,8 @@
" \n",
" def merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):\n",
" num_images, num_image_patches, embed_dim = image_features.shape\n",
" batch_indices, iamge_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])\n",
" inputs_embeds[batch_indices, iamge_indices] = image_features.view(-1, embed_dim)\n",
" batch_indices, image_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])\n",
" inputs_embeds[batch_indices, image_indices] = image_features.view(-1, embed_dim)\n",
" return inputs_embeds"
]
},
Expand Down

0 comments on commit 06c1d79

Please sign in to comment.