23 lines
700 B
Python
23 lines
700 B
Python
import numpy as np
|
|
|
|
|
|
def intersection_over_union(pred, true):
    """
    Calculates the intersection-over-union (IoU) score for a batch of images.

    Pixels whose label equals 255 are excluded from the computation. Among the
    remaining pixels, any non-zero value counts as foreground.

    Args:
        pred (torch.Tensor): a tensor of predictions
        true (torch.Tensor): a tensor of labels; the value 255 marks pixels
            that must be ignored

    Returns:
        torch.Tensor: zero-dimensional floating-point CPU tensor holding
            (number of pixels set in both) / (number of pixels set in either).
            The value is NaN when no valid pixel is set in either tensor,
            because the union is then empty (0 / 0).
    """
    valid_pixel_mask = true.ne(255)  # valid pixel mask
    true = true.masked_select(valid_pixel_mask)
    pred = pred.masked_select(valid_pixel_mask)

    # Use native tensor operations instead of NumPy: this keeps the work on
    # the tensors' own device and also accepts tensors that require grad,
    # which cannot be converted to NumPy arrays.
    true_foreground = true.ne(0)
    pred_foreground = pred.ne(0)

    # Intersection and union totals
    intersection = (true_foreground & pred_foreground).sum()
    union = (true_foreground | pred_foreground).sum()

    # Only the final scalar is moved to the CPU, not every selected pixel.
    return (intersection / union).to("cpu")