Skip to content

Commit

Permalink
fixed "TypeError: 'NoneType' object is not iterable" for reward model… (
Browse files Browse the repository at this point in the history
LAION-AI#3587)

During training the reward model using the `oasst` dataset, using the
following command:
```bash
python trainer_rm.py --configs defaults_rm oasst-rm-1-pythia-1.4b
```
the data loader raises an error:
**TypeError: 'NoneType' object is not iterable**
```
│ Open-Assistant/model/model_training/custom_datasets/oasst_dataset.py:102 in leaf_filter│
│                                                                                        │
│    99   │   │   │   │   # their (ranked) replies as possible continuations.            │
│   100   │   │   │   │   return (                                                       │
│   101   │   │   │   │   │   thread[-1].role == "prompter"                              │
│ > 102   │   │   │   │   │   and len([r for r inthread[-1].replies if r.rank is not None]) > 1                     │
│   103   │   │   │   │   │   and thread_filter(thread)                                  │
│   104   │   │   │   │   )                                                              │
│   105   │   │   │   elif mode == "rl":                                                 │
╰────────────────────────────────────────────────────────────────────────────────────────╯
```
I inserted a condition in line 100 as below to fix this issue and submit
this PR
```python
if thread[-1].replies is None:
    return False
```
  • Loading branch information
egg-west authored Jul 21, 2023
1 parent a11971c commit 11a1842
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions model/model_training/custom_datasets/oasst_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def leaf_filter(thread: list[ExportMessageNode]) -> bool:
elif mode == "rm":
# for reward models we use thread-fragments ending on prompter messages as prefix and
# their (ranked) replies as possible continuations.
if thread[-1].replies is None:
return False
return (
thread[-1].role == "prompter"
and len([r for r in thread[-1].replies if r.rank is not None]) > 1
Expand Down

0 comments on commit 11a1842

Please sign in to comment.