
Decouple input and output representations #22

Open

cheind opened this issue Dec 10, 2021 · 2 comments

Comments

cheind commented Dec 10, 2021

Currently we assume that what we get as input is what we will predict as output (just shifted). However, thinking towards other research areas, it might make sense to rework this more generally:

model
  input: BxIxT
  output: BxQxT

where I might match Q but does not have to. In the training step we would then have code like the following:

def training_step(self, batch):
  inputs = batch['x']
  if 't' in batch:
    targets = batch['t']  # allows us to provide alternative targets
  elif I == Q:
    # default: next-step prediction, targets are the inputs shifted by one
    targets = inputs[..., 1:]
    inputs = inputs[..., :-1]
  else:
    raise ValueError(...)

  logits = self.forward(inputs)
  loss = ce(logits, targets)
  return loss

What's more, we need to think about input transforms. Currently, one-hot encoding is hardwired into the model. We might instead pass a differentiable input_transform to the model upon initialization. This would allow us to use differentiable embedding strategies.
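
As a rough sketch (not existing code in this repo; the class and parameter names are just assumptions), such a pluggable transform could be a small PyTorch module mapping raw class indices BxT to the model input BxIxT:

import torch.nn as nn
import torch.nn.functional as F

class OneHotTransform(nn.Module):
    # non-learnable: BxT integer indices -> BxQxT one-hot floats
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x):
        return F.one_hot(x, self.num_classes).permute(0, 2, 1).float()  # BxQxT

class EmbeddingTransform(nn.Module):
    # learnable, differentiable alternative: BxT indices -> BxIxT embeddings
    def __init__(self, num_classes, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)

    def forward(self, x):
        return self.embed(x).permute(0, 2, 1)  # BxIxT

The model would then call whichever transform instance it receives at construction time instead of hardwiring the one-hot encoding.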

cheind commented Dec 12, 2021

dataset -> model -> loss

model:
    input: BxIxT
    input_transform: fn(BxKxT) -> BxIxT
    condition: BxCxT
    output: BxQxT


def loss(batch, outputs):
    if 't' in batch:
        targets = batch['t'][..., 1:]  # BxQxT or BxT, alternative targets
    else:
        targets = batch['x'][..., 1:]  # 'x' either BxQxT or BxT
    logits = outputs[..., :-1]  # BxQxT, aligned with the shifted targets
    return ce(logits, targets)  # sampling from logits is only needed at generation time


def training_step(self, batch):
    inputs = batch['x']  # BxIxT (or BxT before input_transform)
    condition = batch['c']  # BxCxT
    outputs = self.forward(inputs, condition)
    return loss(batch, outputs)

def forward(self, inputs, cond):
    inputs = self.input_transform(inputs)  # fn(BxKxT) -> BxIxT
    outputs = self.encode(inputs)  # cond still needs to be threaded through encode
    return outputs
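
A minimal usage sketch of how the pieces could fit together (the Model constructor, its arguments, and the concrete shapes below are hypothetical, only meant to illustrate the injected transform):

import torch

B, T, Q, I, C = 8, 64, 256, 32, 16

model = Model(
    input_transform=EmbeddingTransform(num_classes=Q, embed_dim=I),  # hypothetical ctor
    cond_channels=C,
    out_channels=Q,
)

batch = {
    'x': torch.randint(0, Q, (B, T)),  # raw class indices
    'c': torch.randn(B, C, T),         # conditioning signal
}

loss = model.training_step(batch)
loss.backward()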

cheind commented Dec 12, 2021

Would that also work for different model output interpretations, such as #24?
