
How to get issues from all samples in the dataset for semantic segmentation tasks without memory issues #842

Closed
hamzagorgulu opened this issue Sep 5, 2023 · 12 comments
Labels
  • help-wanted: We need your help to add this, but it may be more challenging than a "good first issue"
  • question: A question for Cleanlab maintainers

Comments

@hamzagorgulu

I am trying to get issues from a 180k-image dataset with 20 classes, but the prediction numpy array per image is about 100 MB, so I am unable to store the predictions: I would need about 18 TB of space. I used fp16 to reduce the array size, but it is still too large to handle.

Do you have any suggestions for reducing the size of the predictions in a way that cleanlab still accepts for semantic segmentation?

@jwmueller
Member

Could you share the code you are trying to run?

One thing you can try (I'm not sure if this works) is to pass pred_probs and labels as memory-mapped objects, as sketched below, e.g.:

  • a Zarr array loaded with zarr.convenience.open("YOURFILE.zarr", mode="r"), or
  • a memmap array loaded with np.load("YOURFILE.npy", mmap_mode="r").
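For instance, a minimal loading sketch (the file names are placeholders):

import numpy as np
import zarr

# Zarr array opened lazily in read-only mode (zarr v2 convenience API)
pred_probs = zarr.convenience.open("YOURFILE.zarr", mode="r")

# or: numpy memmap, so slices are only read from disk when accessed
pred_probs = np.load("YOURFILE.npy", mmap_mode="r")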

Another idea: if you're primarily interested in computing a label quality score for each image, i.e. via segmentation.rank.get_label_quality_scores, then you can actually run this method on small batches of images at a time to get their label-quality-scores independently of the rest of the images, i.e.:

scores[i:j] = segmentation.rank.get_label_quality_scores(labels[i:j], pred_probs[i:j])

where j - i is a batch size small enough to load into memory.
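Spelled out, that batching might look roughly like this (just a sketch; it assumes labels and pred_probs are the memory-mapped arrays above and that only the per-image scores need to be kept):

import numpy as np
from cleanlab.segmentation.rank import get_label_quality_scores

batch_size = 50  # pick a size that fits in RAM
num_images = len(labels)
image_scores = np.zeros(num_images)

for i in range(0, num_images, batch_size):
    j = min(i + batch_size, num_images)
    # only this slice of the memory-mapped arrays is materialized in memory
    batch_image_scores, _pixel_scores = get_label_quality_scores(labels[i:j], pred_probs[i:j])
    image_scores[i:j] = batch_image_scores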

@hamzagorgulu
Author

hamzagorgulu commented Sep 6, 2023

Thanks for the reply. I am not sure if using YOURFILE.zarr would work, because right now I am only experimenting with 4.5k out-of-sample predictions; ideally, I need to be able to apply it to 180k. As for segmentation.rank.get_label_quality_scores, I actually did not know it works differently from find_label_issues. If I understand correctly, it is just an algorithm that evaluates each label independently rather than comparing labels against each other, so the data does not all have to be passed in together.

Here is how I read the labels and pred_probs and get the issues. I already use memmap for it.

import numpy as np
from cleanlab.segmentation.filter import find_label_issues

my_label_filepaths = "/mnt/data/projects/hrnetv2/data/gt_and_predprobs/labels.npy"
my_labels = np.load(my_label_filepaths, mmap_mode='r')

my_pred_probs_filepaths = "/mnt/data/projects/hrnetv2/data/gt_and_predprobs/merged.npy"
my_pred_probs = np.load(my_pred_probs_filepaths, mmap_mode='r', allow_pickle=False)

issues = find_label_issues(my_labels, my_pred_probs, downsample=1, n_jobs=None, batch_size=10)

And here is the error I get with this code:

Traceback (most recent call last):
  File "segmentation.py", line 44, in <module>
    issues = find_label_issues(my_labels, my_pred_probs, downsample = 1, n_jobs=None, batch_size=10)
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/segmentation/filter.py", line 157, in find_label_issues
    pre_labels, pre_pred_probs = flatten_and_preprocess_masks(pre_labels, pre_pred_probs)
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/segmentation/filter.py", line 127, in flatten_and_preprocess_masks
    chunk_size = 50  # Define a suitable chunk size
numpy.core._exceptions.MemoryError: Unable to allocate 41.7 GiB for an array with shape (1119744000, 20) and data type float16

Before getting this error, I realized the flatten_and_preprocess_masks function causes the problem while it is flattening the pred_probs arrays, and I tried to chunk it, but I guess that does not work.

Additionally, I tried using find_label_issues_batched as follows:

issues = find_label_issues_batched(
    labels=my_labels, pred_probs=my_pred_probs, batch_size=batch_size
)

But the error I get is again related to memory, which is weird because it changes the shape of the pred_probs:

Traceback (most recent call last):
  File "segmentation.py", line 40, in <module>
    issues = find_label_issues_batched(
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/experimental/label_issues_batched.py", line 231, in find_label_issues_batched
    _ = lab.score_label_quality(labels_batch, pred_probs_batch)
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/experimental/label_issues_batched.py", line 565, in score_label_quality
    scores = _compute_label_quality_scores(
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/rank.py", line 166, in _compute_label_quality_scores
    label_quality_scores = scoring_func(**scoring_inputs)
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/rank.py", line 510, in get_self_confidence_for_each_label
    return np.array([pred_probs[i, l] for i, l in enumerate(labels)])
  File "/mnt/data/projects/hrnetv2/data/cleanlab/cleanlab/rank.py", line 510, in <listcomp>
    return np.array([pred_probs[i, l] for i, l in enumerate(labels)])
  File "/mnt/data/venvs/hrnet/lib/python3.8/site-packages/numpy/core/memmap.py", line 334, in __getitem__
    res = super().__getitem__(index)
numpy.core._exceptions.MemoryError: Unable to allocate 115. GiB for an array with shape (432, 576, 432, 576) and data type float16

@hamzagorgulu
Author

hamzagorgulu commented Sep 6, 2023

I can actually get the label quality scores and also the issues without problems now. The memory issue only occurs when I try to get the label issues directly using find_label_issues. I am not sure if it will still work for 180k images, but here is how I get the issues:

image_scores, pixel_scores = get_label_quality_scores(my_labels, my_pred_probs) 
issues_from_score = issues_from_scores(image_scores, pixel_scores, threshold=0.5)

The extra step I do here is to specify a threshold value of 0.5. Do you think this is similar to what find_label_issues does?

@hamzagorgulu
Author

I have used .zarr compression, and it compressed the numpy prediction arrays from 100 MB to 5 MB. I can confirm the compression is lossless. But I still could not feed them into cleanlab, since they still need 900 GB of storage space, which is huge.

@jwmueller
Member

Thanks for providing the additional information; your workaround sounds good to me, and I'd proceed with that for your data.

@vdlad is looking further into this issue. We suspect the bottleneck is in this line of code:

pred_probs_flat = np.moveaxis(pred_probs, 0, 1).reshape(num_classes, -1)

where either flatten(), np.moveaxis, or reshape() is causing the memmap array to be loaded entirely into memory all at once, triggering the RAM error. We'd also appreciate any suggestions to improve this line of code for mem-mapped/zarr arrays, based on your experience!
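For reference, a chunked rewrite of that line could look roughly like this (only a sketch, assuming pred_probs has shape (num_images, num_classes, H, W) and writing the flattened result into a preallocated .npy memmap whose file name is a placeholder; this is not the actual fix):

import numpy as np

# pred_probs: memmap/zarr array of shape (N, K, H, W)
N, K, H, W = pred_probs.shape
out = np.lib.format.open_memmap(
    "pred_probs_flat.npy", mode="w+", dtype=pred_probs.dtype, shape=(K, N * H * W)
)

chunk = 50  # number of images processed per step
for i in range(0, N, chunk):
    j = min(i + chunk, N)
    block = np.asarray(pred_probs[i:j])              # (j-i, K, H, W), read from disk
    block = np.moveaxis(block, 0, 1).reshape(K, -1)  # (K, (j-i)*H*W)
    out[:, i * H * W : j * H * W] = block            # same layout as the one-shot reshape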

@hamzagorgulu
Author

hamzagorgulu commented Sep 10, 2023

The chunking approach I used was as follows:

import os
import numpy as np
import zarr
from cleanlab.segmentation.rank import get_label_quality_scores, issues_from_scores

chunk_size = 500
image_scores_lst = []
zarr_file_path = "issues_from_score.zarr"

for idx in range(0, len(my_labels), chunk_size):
    print(f"Starting loop {idx // chunk_size + 1}")
    label_chunk = my_labels[idx:idx+chunk_size]
    pred_prob_chunk = my_pred_probs[idx:idx+chunk_size]
    image_scores, pixel_scores = get_label_quality_scores(label_chunk, pred_prob_chunk)
    issues_from_score = issues_from_scores(image_scores, pixel_scores, threshold=0.5)

    # save issues_from_score with .zarr
    append_with_zarr(zarr_file_path, issues_from_score, dtype=np.uint8)

    image_scores_lst.append(image_scores)

As shown, I process the data in chunks when getting the issues. However, each issues_from_score mask corresponds to a (W, H) array, and memory cannot hold 180k * (W, H) of them at once, so I also store the issues_from_score chunks using .zarr. The functions for reading from and appending to the .zarr file are as follows:

def append_with_zarr(filename, array_to_be_appended, dtype):
    """
    Appends a given NumPy array to an existing Zarr array or creates a new Zarr array.

    Args:
        filename (str): Path to the Zarr file.
        array_to_be_appended (np.ndarray): NumPy array to append.
        dtype (str or np.dtype): Data type to which the input array should be cast.
    """
    compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.NOSHUFFLE)
    
    # Initialize Zarr array if it doesn't exist
    if not os.path.exists(filename):
        initial_shape = (0,) + array_to_be_appended.shape[1:]
        # honor the requested dtype instead of hard-coding float16
        zarr_array = zarr.zeros(shape=initial_shape, chunks=array_to_be_appended.shape, dtype=dtype, store=filename, compressor=compressor)
    else:
        zarr_array = zarr.open(filename, mode='a', compressor=compressor)

    array_to_be_appended = array_to_be_appended.astype(dtype)
    
    new_shape = (zarr_array.shape[0] + array_to_be_appended.shape[0],) + array_to_be_appended.shape[1:]
    zarr_array.resize(new_shape)
    zarr_array[-array_to_be_appended.shape[0]:] = array_to_be_appended

def read_zarr_file(filename):
    """
    Reads a Zarr file and returns its content.

    Args:
        filename (str): Path to the Zarr file.

    Returns:
        zarr.core.Array: Zarr array containing the data.
    """
    zarr_file = zarr.open(filename, mode='r') 
    return zarr_file

I think I will be able to get results, but I am not sure if this approach is equivalent to the find_label_issues() function, because in the approach I provided the chunks are not out-of-sample predictions but sliced pred_probs from the same model. Do you think this is a good approach? Or should I instead combine pred_probs slices from different models and then pass them to the get_label_quality_scores() function? I am not sure if this would work, since I have not had time to explore how the algorithm works.

@hamzagorgulu
Author

Hi there @vdlad. Have you been able to implement such an approach?

I have tried using a chunking method to get the results and save them on the fly. Even though I got results, they are worse at finding label mistakes than the method I used above (get_label_quality_scores and then issues_from_scores), so I am not sure what I did is correct. Here is the chunking approach for find_label_issues:

find_label_issues_path = "path_to_zarr_file.zarr"
chunk = 100
for idx in range(0, len(pred_probs_dict["fold1"]), chunk):
    print(idx)
    pred_probs_stack = np.concatenate([pred_probs_dict[fold][idx:idx+chunk] for fold in folds], axis=0)  # a chunk of this length from every fold
    labels_stack = np.concatenate([labels_dict[fold][idx:idx+chunk] for fold in folds], axis=0)
    issues = find_label_issues(labels=labels_stack, pred_probs=pred_probs_stack, n_jobs=8, downsample=2)
    append_with_zarr(find_label_issues_path, issues, dtype="bool")

In summary, I concatenate chunks from different folds, which represent predictions from different models, and pass them to the find_label_issues function. Then I append the results to a .zarr file.

@jwmueller
Member

@hamzagorgulu For your data, I recommend simply doing:

image_scores, pixel_scores = get_label_quality_scores(my_labels, my_pred_probs) 
issues_from_score = issues_from_scores(image_scores, pixel_scores, threshold=T)

where you'll want to visually estimate what is a good threshold T by looking at the images whose label-quality-scores lie around a candidate T value.

This will not give the same results as:

segmentation.filter.find_label_issues(..., downsample=D)

which is additionally intended to help estimate the number of mislabeled images.
The latter method should NOT be run in small batches as you've done; it needs access to the full dataset. One thing you could try is increasing the value of D beyond 1 to, say, 16, which should require less RAM. But in any case, what you've done with issues_from_scores should work well for your dataset if you invest the effort to select the right T by inspecting a few values.
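If it helps, one rough way to eyeball a good T is to look at the images whose scores sit closest to a candidate threshold (just a sketch; loading and displaying the actual images and masks is up to you):

import numpy as np

T = 0.5  # candidate threshold
# indices of the 20 images whose quality score falls nearest to T
near_T = np.argsort(np.abs(image_scores - T))[:20]
for idx in near_T:
    print(idx, image_scores[idx])
    # visually inspect image idx and its label mask here,
    # then adjust T until images scoring below it genuinely look mislabeled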

@hamzagorgulu
Author

Thanks for the answer. Just asking to be sure: I don't have to provide out-of-sample predictions for the first approach, right? I can iterate over the folds one by one.

@jwmueller
Member

I don't have to provide out-of-sample predictions for the first approach, right? I can iterate over the folds one by one.

Yep, because get_label_quality_scores can be computed in independent mini-batches. But you should still do cross-validation in general, to ensure you aren't providing pred_probs for the same data the segmentation model was trained on (you can do that, but the label-error detection will simply be less accurate due to overfitting).
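A rough sketch of that cross-validation setup (train_segmentation_model and predict_probs are hypothetical placeholders for your own training/inference code, not cleanlab functions):

import numpy as np
from sklearn.model_selection import KFold

# pred_probs: preallocated memmap/zarr array of shape (N, K, H, W)
indices = np.arange(num_images)
kf = KFold(n_splits=5, shuffle=True, random_state=0)

for train_idx, val_idx in kf.split(indices):
    # train only on this fold's training split
    model = train_segmentation_model(train_idx)      # placeholder
    # predict on the held-out split, so these pred_probs are out-of-sample
    for i in val_idx:
        pred_probs[i] = model.predict_probs(i)       # placeholder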

@jwmueller
Member

Tracking a fix for the original issue here: #863

@jwmueller
Member

This is now completed in: #885

Follow-up improvement: #918
