Commit ea27331

+ new icon
+ decouple module
+ support LSTUR again

Jyonn committed Dec 4, 2024
1 parent 9ea891d commit ea27331
Showing 18 changed files with 172 additions and 120 deletions.
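The headline change is the decoupling: Legommender no longer receives seven collaborators directly; Controller now packs them into a Preparer object that is passed to the model and threaded down into operators and predictors. model/preparer.py itself is not part of this diff, so the sketch below only mirrors its call sites; anything beyond the constructor arguments visible in the diffs is an assumption.

# Sketch of the new wiring, inferred from the loader/controller.py and
# model/legommender.py diffs below. model/preparer.py is not in this commit,
# so this class body is an assumption based purely on how it is called.

class Preparer:
    """Bundles the collaborators that Legommender previously took one by one."""

    def __init__(self, meta, status, config, column_map,
                 embedding_manager, user_hub, item_hub):
        self.meta = meta                            # LegommenderMeta: encoder/predictor classes
        self.status = status
        self.config = config                        # LegommenderConfig
        self.column_map = column_map
        self.embedding_manager = embedding_manager
        self.user_hub = user_hub
        self.item_hub = item_hub

# Before: Legommender(meta=..., status=..., config=..., column_map=..., ...)
# After:  Legommender(preparer=preparer, user_plugin=user_plugin)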
Binary file removed Legommenders.png
Binary file added assets/lego.png
4 changes: 2 additions & 2 deletions config/exp/llama-split.yaml
@@ -3,9 +3,9 @@ dir: saving/${data.name}/${model.name}/${embed.name}-${exp.name}
 log: ${exp.dir}/exp.log
 mode: test_llm_layer_split
 store:
-  layers: [31, 30, 29, 27] # 7b
+  layers: [31, 30] # 7b
   # layers: [39, 38, 37, 35] # 13b
-  dir: data/${data.name}/llama-${llm_ver}-split
+  dir: /home/data4/qijiong/Data/Lego/${data.name}/llama-${llm_ver}-split
 load:
   save_dir: null
   model_only: true
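For context on the two changed lines: in test_llm_layer_split mode the frozen lower LLaMA layers are run once and their hidden states are cached at each index listed in store.layers, so later experiments fine-tune only the layers above the split; the dir change merely moves that cache to an absolute path. A rough sketch of the caching idea, with assumed helper names (the repo's actual implementation is not part of this diff):

import os
import torch

def cache_layer_split_states(trunk, item_batches, split_layers, save_dir):
    # Run the item corpus once through the frozen lower layers and store the
    # hidden states at every candidate split index, so top-layer fine-tuning
    # can start from the cache instead of recomputing the bottom of the
    # network every epoch.
    # Simplified: real LLaMA decoder layers also take attention masks and
    # position ids and return tuples; `trunk.embed` and `trunk.layers` are
    # assumed names, not the repo's API.
    os.makedirs(save_dir, exist_ok=True)
    buffers = {idx: [] for idx in split_layers}
    with torch.no_grad():
        for batch in item_batches:
            hidden = trunk.embed(batch)
            for idx, layer in enumerate(trunk.layers):
                hidden = layer(hidden)
                if idx in buffers:
                    buffers[idx].append(hidden.cpu())
    for idx, chunks in buffers.items():
        torch.save(torch.cat(chunks), os.path.join(save_dir, f'layer_{idx}.pt'))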
4 changes: 2 additions & 2 deletions config/model/llm/llama-naml.yaml
@@ -8,14 +8,14 @@ config:
   use_item_content: true
   max_item_content_batch_size: ${max_item_batch_size}$
   same_dim_transform: false
-  embed_hidden_size: ${embed_hidden_sijze}$
+  embed_hidden_size: ${embed_hidden_size}$
   hidden_size: ${hidden_size}$
   neg_count: 4
   item_config:
     llm_dir: /home/data1/qijiong/llama-${llm_ver}
     layer_split: ${layer}$
     lora: ${lora}$
-    weights_dir: data/${data.name}/llama-${llm_ver}-split
+    weights_dir: /home/data4/qijiong/Data/Lego/${data.name}/llama-${llm_ver}-split
   user_config:
     num_attention_heads: 12
   inputer_config:
13 changes: 9 additions & 4 deletions loader/controller.py
@@ -5,6 +5,7 @@
 from pigmento import pnt
 from torch import nn

+from model.meta_config import LegommenderMeta, LegommenderConfig
 from loader.data_hubs import DataHubs
 from loader.data_sets import DataSets
 from loader.depot.depot_hub import DepotHub
@@ -15,13 +16,14 @@
 from model.inputer.concat_inputer import ConcatInputer
 from model.inputer.flatten_seq_inputer import FlattenSeqInputer
 from model.inputer.natural_concat_inputer import NaturalConcatInputer
-from model.legommender import Legommender, LegommenderConfig, LegommenderMeta
+from model.legommender import Legommender
 from loader.column_map import ColumnMap
 from loader.embedding.embedding_hub import EmbeddingHub
-from loader.resampler import Resampler
 from loader.data_loader import DataLoader
 from loader.data_hub import DataHub
 from loader.class_hub import ClassHub
+from model.preparer import Preparer
+from loader.resampler import Resampler


 class Controller:
@@ -34,7 +36,7 @@ def __init__(self, data, embed, model, exp):

         self.status = Status()

-        if 'MIND' in self.data.name.upper():
+        if 'MIND' in self.data.name.upper() or 'EB-NERD' in self.data.name.upper():
             Meta.data_type = DatasetType.news
         else:
             Meta.data_type = DatasetType.book
@@ -113,14 +115,17 @@ def __init__(self, data, embed, model, exp):

         # legommender initialization
         # self.legommender = self.legommender_class(
-        self.legommender = Legommender(
+        self.preparer = Preparer(
             meta=self.legommender_meta,
             status=self.status,
             config=self.legommender_config,
             column_map=self.column_map,
             embedding_manager=self.embedding_hub,
             user_hub=self.hubs.a_hub(),
             item_hub=self.item_hub,
+        )
+        self.legommender = Legommender(
+            preparer=self.preparer,
             user_plugin=user_plugin,
         )
         self.resampler = Resampler(
5 changes: 3 additions & 2 deletions loader/resampler.py
@@ -6,10 +6,11 @@

 from loader.meta import Meta
 from loader.status import Status
-from model.legommender import Legommender, LegommenderConfig
+from model.legommender import Legommender
 from loader.data_hub import DataHub
 from loader.data_set import DataSet
-from utils.stacker import Stacker, FastStacker
+from model.meta_config import LegommenderConfig
+from utils.stacker import FastStacker
 from utils.timer import Timer


3 changes: 0 additions & 3 deletions model/common/fastformer.py
@@ -1,5 +1,3 @@
-import logging
-
 import torch
 from torch import nn
 from transformers.models.bert.modeling_bert import BertIntermediate, BertOutput, BertSelfOutput
@@ -180,7 +178,6 @@ def __init__(self, config: FastformerConfig, pooler_count=1):
         if config.pooler_type == 'weightpooler':
             for _ in range(pooler_count):
                 self.poolers.append(AttentionPooling(config))
-        logging.info(f"This model has {len(self.poolers)} poolers.")

         self.apply(self.init_weights)

3 changes: 3 additions & 0 deletions model/inputer/llm_concat_inputer.py
@@ -16,11 +16,14 @@ def get_start_prompt():
     def get_col_prompts():
         return dict(
             newtitle=[529, 3257, 29958],
+            subtitle=[529, 1491, 3257, 29958],
             title=[529, 3257, 29958],
             abs=[529, 16595, 29958],
             cat=[529, 7320, 29958],
+            category=[529, 7320, 29958],
             subCat=[529, 1491, 7320, 29958],
             desc=[529, 16595, 29958],
+            body=[529, 2587, 29958],
         )


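The hard-coded ID lists above are LLaMA-tokenizer encodings of angle-bracket field tags (e.g. [529, 3257, 29958] for <title>); the new entries add <subtitle>, <category>, and <body> so that differently named dataset columns map onto the same prompts. A quick way to check the mapping, assuming a LLaMA-family checkpoint with the original 32k-token vocabulary is available locally (the path is a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('/path/to/llama-7b')  # placeholder path

for col, ids in dict(
    title=[529, 3257, 29958],
    subtitle=[529, 1491, 3257, 29958],
    category=[529, 7320, 29958],
    body=[529, 2587, 29958],
).items():
    # expected output: <title>, <subtitle>, <category>, <body>
    print(col, '->', tokenizer.decode(ids))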
117 changes: 30 additions & 87 deletions model/legommender.py
@@ -1,122 +1,62 @@
-from typing import Type
-
 import torch
 from pigmento import pnt
 from torch import nn

 from loader.meta import Meta
 from loader.status import Status
 from model.common.base_module import BaseModule
-from model.common.mediator import Mediator
 from model.common.user_plugin import UserPlugin
+from model.meta_config import LegommenderConfig
 from model.operators.base_llm_operator import BaseLLMOperator
-from model.operators.base_operator import BaseOperator
-from model.predictors.base_predictor import BasePredictor
 from loader.cacher.repr_cacher import ReprCacher
-from loader.column_map import ColumnMap
-from loader.embedding.embedding_hub import EmbeddingHub
-from loader.data_hub import DataHub
+from model.preparer import Preparer
 from utils.function import combine_config
 from utils.shaper import Shaper


-class LegommenderMeta:
-    def __init__(
-            self,
-            item_encoder_class: Type[BaseOperator],
-            user_encoder_class: Type[BaseOperator],
-            predictor_class: Type[BasePredictor],
-    ):
-        self.item_encoder_class = item_encoder_class
-        self.user_encoder_class = user_encoder_class
-        self.predictor_class = predictor_class
-
-
-class LegommenderConfig:
-    def __init__(
-            self,
-            hidden_size,
-            user_config,
-            use_neg_sampling: bool = True,
-            neg_count: int = 4,
-            embed_hidden_size=None,
-            item_config=None,
-            predictor_config=None,
-            use_item_content: bool = True,
-            max_item_content_batch_size: int = 0,
-            same_dim_transform: bool = True,
-            page_size: int = 512,
-            **kwargs,
-    ):
-        self.hidden_size = hidden_size
-        self.item_config = item_config
-        self.user_config = user_config
-        self.predictor_config = predictor_config or {}
-
-        self.use_neg_sampling = use_neg_sampling
-        self.neg_count = neg_count
-        self.use_item_content = use_item_content
-        self.embed_hidden_size = embed_hidden_size or hidden_size
-
-        self.max_item_content_batch_size = max_item_content_batch_size
-        self.same_dim_transform = same_dim_transform
-
-        self.page_size = page_size
-
-        if self.use_item_content:
-            if not self.item_config:
-                self.item_config = {}
-                # raise ValueError('item_config is required when use_item_content is True')
-                pnt('automatically set item_config to an empty dict, as use_item_content is True')
-
-
 class Legommender(BaseModule):
     def __init__(
             self,
-            meta: LegommenderMeta,
-            status: Status,
-            config: LegommenderConfig,
-            column_map: ColumnMap,
-            embedding_manager: EmbeddingHub,
-            user_hub: DataHub,
-            item_hub: DataHub,
+            preparer: Preparer,
             user_plugin: UserPlugin = None,
     ):
         super().__init__()

+        self.preparer = preparer
+
         """initializing basic attributes"""
-        self.meta = meta
-        self.status = status
-        self.item_encoder_class = meta.item_encoder_class
-        self.user_encoder_class = meta.user_encoder_class
-        self.predictor_class = meta.predictor_class
+        # self.meta = meta
+        self.status = self.preparer.status
+        self.item_encoder_class = self.preparer.meta.item_encoder_class
+        self.user_encoder_class = self.preparer.meta.user_encoder_class
+        self.predictor_class = self.preparer.meta.predictor_class

-        self.use_neg_sampling = config.use_neg_sampling
-        self.neg_count = config.neg_count
+        self.use_neg_sampling = self.preparer.config.use_neg_sampling
+        self.neg_count = self.preparer.config.neg_count

-        self.config = config  # type: LegommenderConfig
+        self.config = self.preparer.config  # type: LegommenderConfig

-        self.embedding_manager = embedding_manager
-        self.embedding_table = embedding_manager.get_table()
+        self.embedding_manager = self.preparer.embedding_manager
+        self.embedding_table = self.preparer.embedding_manager.get_table()

-        self.user_hub = user_hub
-        self.item_hub = item_hub
+        self.user_hub = self.preparer.user_hub
+        self.item_hub = self.preparer.item_hub

-        self.column_map = column_map  # type: ColumnMap
-        self.user_col = column_map.user_col
-        self.clicks_col = column_map.clicks_col
-        self.candidate_col = column_map.candidate_col
-        self.label_col = column_map.label_col
-        self.clicks_mask_col = column_map.clicks_mask_col
+        self.column_map = self.preparer.column_map  # type: ColumnMap
+        self.user_col = self.column_map.user_col
+        self.clicks_col = self.column_map.clicks_col
+        self.candidate_col = self.column_map.candidate_col
+        self.label_col = self.column_map.label_col
+        self.clicks_mask_col = self.column_map.clicks_mask_col

         """initializing core components"""
         self.flatten_mode = self.user_encoder_class.flatten_mode
-        self.user_encoder = self.prepare_user_module()
         self.item_encoder = None
         if self.config.use_item_content:
             self.item_encoder = self.prepare_item_module()
+        self.user_encoder = self.prepare_user_module()
         self.predictor = self.prepare_predictor()
-        self.mediator = Mediator(self)
+        # self.mediator = Mediator(self)

         """initializing extra components"""
         self.user_plugin = user_plugin
@@ -174,6 +114,7 @@ def get_item_content(self, batch, col):
             end = min((i + 1) * allow_batch_size, sample_size)
             mask = None if attention_mask is None else attention_mask[start:end]
             content = self.item_encoder(item_content[start:end], mask=mask)
+            # print(Structure().analyse_and_stringify(content))
             item_contents[start:end] = content

         if not self.llm_skip:
@@ -255,7 +196,7 @@ def __repr__(self):
     def prepare_user_module(self):
         user_config = self.user_encoder_class.config_class(**combine_config(
             config=self.config.user_config,
-            hidden_size=self.config.hidden_size,
+            hidden_size=self.item_encoder.export_hidden_size(),
             embed_hidden_size=self.config.embed_hidden_size,
             input_dim=self.config.hidden_size,
         ))
@@ -268,6 +209,7 @@ def prepare_user_module(self):
             hub=self.user_hub,
             embedding_manager=self.embedding_manager,
             target_user=True,
+            preparer=self.preparer,
         )

     def prepare_item_module(self):
@@ -283,6 +225,7 @@ def prepare_item_module(self):
             hub=self.item_hub,
             embedding_manager=self.embedding_manager,
             target_user=False,
+            preparer=self.preparer,
         )

     def prepare_predictor(self):
@@ -298,7 +241,7 @@ def prepare_predictor(self):
             embed_hidden_size=self.config.embed_hidden_size,
         ))

-        return self.predictor_class(config=predictor_config)
+        return self.predictor_class(config=predictor_config, preparer=self.preparer)

     def get_parameters(self):
         pretrained_parameters = []
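Two related adjustments above are worth connecting: the user encoder is now built after the item encoder, and its hidden_size comes from self.item_encoder.export_hidden_size() rather than the global config.hidden_size. This lets an item operator, for example an LLM operator, emit representations at its own width while the user encoder sizes itself accordingly. A minimal illustration of that contract, with a hypothetical operator (export_hidden_size appears in this diff; everything else here is illustrative):

# Hypothetical item operator illustrating the export_hidden_size contract.
class ToyLLMItemOperator:
    def __init__(self, llm_width: int = 4096):
        self.llm_width = llm_width  # e.g. LLaMA-7B hidden width

    def export_hidden_size(self) -> int:
        # The user encoder takes this as its hidden_size, so item
        # representations of any width feed it without a manual config edit.
        return self.llm_width

Note that when use_item_content is disabled, self.item_encoder stays None, so this call path presumably only runs for content-based models.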
56 changes: 56 additions & 0 deletions model/meta_config.py
@@ -0,0 +1,56 @@
+from typing import Type
+
+from pigmento import pnt
+
+from model.operators.base_operator import BaseOperator
+from model.predictors.base_predictor import BasePredictor
+
+
+class LegommenderMeta:
+    def __init__(
+            self,
+            item_encoder_class: Type[BaseOperator],
+            user_encoder_class: Type[BaseOperator],
+            predictor_class: Type[BasePredictor],
+    ):
+        self.item_encoder_class = item_encoder_class
+        self.user_encoder_class = user_encoder_class
+        self.predictor_class = predictor_class
+
+
+class LegommenderConfig:
+    def __init__(
+            self,
+            hidden_size,
+            user_config,
+            use_neg_sampling: bool = True,
+            neg_count: int = 4,
+            embed_hidden_size=None,
+            item_config=None,
+            predictor_config=None,
+            use_item_content: bool = True,
+            max_item_content_batch_size: int = 0,
+            same_dim_transform: bool = True,
+            page_size: int = 512,
+            **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.item_config = item_config
+        self.user_config = user_config
+        self.predictor_config = predictor_config or {}
+
+        self.use_neg_sampling = use_neg_sampling
+        self.neg_count = neg_count
+        self.use_item_content = use_item_content
+        self.embed_hidden_size = embed_hidden_size or hidden_size
+
+        self.max_item_content_batch_size = max_item_content_batch_size
+        self.same_dim_transform = same_dim_transform
+
+        self.page_size = page_size
+
+        if self.use_item_content:
+            if not self.item_config:
+                self.item_config = {}
+                # raise ValueError('item_config is required when use_item_content is True')
+                pnt('automatically set item_config to an empty dict, as use_item_content is True')
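LegommenderConfig defaults embed_hidden_size to hidden_size and, when use_item_content is true, substitutes an empty item_config instead of raising. A small usage sketch (the values are illustrative, not from the repo's configs):

from model.meta_config import LegommenderConfig

config = LegommenderConfig(
    hidden_size=256,
    user_config=dict(num_attention_heads=12),
    # item_config deliberately omitted: with use_item_content=True it is
    # auto-set to {} and a notice is printed via pnt
)
assert config.embed_hidden_size == 256   # falls back to hidden_size
assert config.item_config == {}
assert config.predictor_config == {}     # None is normalized to {}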
