Skip to content

prediction_saver

SavePredictions

Bases: BasePredictionWriter

A class that saves model predictions.

Attributes:

- `path` (str): The path of the output CSV file.
- `save_preview_samples` (bool): If True, save preview images.
- `keys` (List[str]): A list of keys (stored on the instance; not used by the methods shown here).

Source code in fmcib/callbacks/prediction_saver.py
class SavePredictions(BasePredictionWriter):
    """
    Prediction writer that collects model outputs into a CSV file.

    At the end of a prediction epoch, each row returned by the ``predict``
    dataset's ``get_rows()`` is extended with one ``pred_<i>`` column per
    prediction value, and all rows are written to the output CSV. Optionally,
    preview images of the first input samples are saved next to the CSV.

    Attributes:
        output_csv (Path): The path of the output CSV file.
        save_preview_samples (bool): If True, save preview images.
        keys (List[str]): A list of keys (stored; not used by the methods below).
    """

    def __init__(self, path: str, save_preview_samples: bool = False, keys: List[str] = None):
        """
        Initialize the writer and make sure the CSV's parent directory exists.

        Args:
            path (str): The path of the output CSV file.
            save_preview_samples (bool, optional): Whether to save preview images
                of input samples. Defaults to False.
            keys (List[str], optional): A list of keys, stored as ``self.keys``.
                Defaults to None.
        """
        # NOTE(review): "epoch" is presumably the base class's write interval,
        # matching the use of `write_on_epoch_end` below — confirm.
        super().__init__("epoch")
        self.output_csv = Path(path)
        self.keys = keys
        self.save_preview_samples = save_preview_samples
        self.output_csv.parent.mkdir(parents=True, exist_ok=True)

    def save_preview_image(self, data, tag):
        """
        Save one preview image as ``<tag>.png``.

        The file goes into a ``previews_<csv stem>`` directory next to the
        output CSV. The directory is created if needed and stored as
        ``self.output_dir``.

        Args:
            data (tuple): A 2-tuple whose first element is the image; the
                second element is ignored.
            tag (str | int): File name stem for the preview
                (``write_on_epoch_end`` passes the sample index).
        """
        self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
        self.output_dir.mkdir(parents=True, exist_ok=True)
        image, _ = data
        image = handle_image(image)
        fp = self.output_dir / f"{tag}.png"
        torchvision.utils.save_image(image, fp)

    def write_on_epoch_end(
        self,
        trainer,
        pl_module: "LightningModule",
        predictions: List[Any],
        batch_indices: List[Any],
    ):
        """
        Attach predictions to the predict dataset's rows and write them to CSV.

        Args:
            trainer: The trainer object (not used).
            pl_module (LightningModule): The Lightning module; its
                ``datasets["predict"]`` entry supplies the rows and samples.
            predictions (List[Any]): Per-batch outputs; each element's
                ``"pred"`` entry holds that batch's per-sample predictions.
            batch_indices (List[Any]): A list of batch indices (not used).

        Raises:
            AssertionError: If ``"predict"`` is not present in ``pl_module.datasets``.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so callers see the same exception type.
        if "predict" not in pl_module.datasets:
            raise AssertionError("`predict` dataset not defined in `pl_module.datasets`")
        dataset = pl_module.datasets["predict"]

        # Flatten the per-batch outputs into one prediction per sample.
        flat_predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]

        rows = []
        # zip stops at the shorter of rows/predictions.
        for idx, (row, pred) in enumerate(zip(dataset.get_rows(), flat_predictions)):
            for i, value in enumerate(pred):
                row[f"pred_{i}"] = value.item()
            rows.append(row)

            # Check truthiness first: `0 <= False` is True, so without this guard
            # a preview of sample 0 is saved even when previews are disabled.
            # NOTE(review): with True (== 1) this previews samples 0 and 1; if a
            # sample count was intended, `<` may be meant — confirm.
            if self.save_preview_samples and idx <= self.save_preview_samples:
                self.save_preview_image(dataset[idx], idx)

        pd.DataFrame(rows).to_csv(self.output_csv)

__init__(path, save_preview_samples=False, keys=None)

Initialize an instance of the class.

Parameters:

- `path` (str, required): The path of the output CSV file; its parent directory is created if it does not exist.
- `save_preview_samples` (bool, default `False`): Whether to save preview images of input samples.
- `keys` (List[str], default `None`): A list of keys, stored on the instance.

Returns:

- None

Source code in fmcib/callbacks/prediction_saver.py
def __init__(self, path: str, save_preview_samples: bool = False, keys: List[str] = None):
    """
    Configure the writer and create the directory that will hold the CSV.

    Args:
        path (str): Destination of the output CSV file; its parent directory
            is created if missing.
        save_preview_samples (bool, optional): Whether preview images should be
            written. Defaults to False.
        keys (List[str], optional): Stored as ``self.keys``. Defaults to None.
    """
    # NOTE(review): "epoch" is presumably the base writer's write interval — confirm.
    super().__init__("epoch")
    csv_path = Path(path)
    self.output_csv = csv_path
    self.keys = keys
    self.save_preview_samples = save_preview_samples
    # Ensure the destination folder exists before anything is written there.
    csv_path.parent.mkdir(exist_ok=True, parents=True)

save_preview_image(data, tag)

Save a preview image to a specified directory.

Parameters:

- `data` (tuple, required): A 2-tuple whose first element is the image to save; the second element is ignored.
- `tag` (str or int, required): The file name stem of the preview; the image is written to `<tag>.png`.

Returns:

- None

Source code in fmcib/callbacks/prediction_saver.py
def save_preview_image(self, data, tag):
    """
    Write a single preview image named ``<tag>.png``.

    The image lands in a ``previews_<csv stem>`` folder beside the output CSV.
    The folder is created on demand and remembered as ``self.output_dir``.

    Args:
        data (tuple): A 2-tuple; its first element is the image, the second is
            discarded.
        tag (str | int): File name stem for the written PNG.
    """
    csv_path = self.output_csv
    preview_dir = csv_path.parent / f"previews_{csv_path.stem}"
    self.output_dir = preview_dir
    preview_dir.mkdir(exist_ok=True, parents=True)

    raw_image, _discarded = data
    prepared = handle_image(raw_image)
    torchvision.utils.save_image(prepared, preview_dir / f"{tag}.png")

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Write predictions on epoch end.

Parameters:

- `trainer` (required): The trainer object (not used).
- `pl_module` (LightningModule, required): The Lightning module; its `datasets["predict"]` entry supplies the rows and input samples.
- `predictions` (List[Any], required): Per-batch outputs; each element's `"pred"` entry holds that batch's per-sample predictions.
- `batch_indices` (List[Any], required): A list of batch indices (not used).

Raises:

- `AssertionError`: If `"predict"` is not present in `pl_module.datasets`.

Returns:

- None

Source code in fmcib/callbacks/prediction_saver.py
def write_on_epoch_end(
    self,
    trainer,
    pl_module: "LightningModule",
    predictions: List[Any],
    batch_indices: List[Any],
):
    """
    Attach predictions to the predict dataset's rows and write them to CSV.

    Each row from ``dataset.get_rows()`` gains one ``pred_<i>`` column per
    prediction value. All rows are written with ``DataFrame.to_csv`` to
    ``self.output_csv``.

    Args:
        self: The writer instance (uses ``output_csv``, ``save_preview_samples``
            and ``save_preview_image``).
        trainer: The trainer object (not used).
        pl_module (LightningModule): The Lightning module; its
            ``datasets["predict"]`` entry supplies the rows and samples.
        predictions (List[Any]): Per-batch outputs; each element's ``"pred"``
            entry holds that batch's per-sample predictions.
        batch_indices (List[Any]): A list of batch indices (not used).

    Raises:
        AssertionError: If ``"predict"`` is not present in ``pl_module.datasets``.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so callers see the same exception type.
    if "predict" not in pl_module.datasets:
        raise AssertionError("`predict` dataset not defined in `pl_module.datasets`")
    dataset = pl_module.datasets["predict"]

    # Flatten the per-batch outputs into one prediction per sample.
    flat_predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]

    rows = []
    # zip stops at the shorter of rows/predictions.
    for idx, (row, pred) in enumerate(zip(dataset.get_rows(), flat_predictions)):
        for i, value in enumerate(pred):
            row[f"pred_{i}"] = value.item()
        rows.append(row)

        # Check truthiness first: `0 <= False` is True, so without this guard
        # a preview of sample 0 is saved even when previews are disabled.
        # NOTE(review): with True (== 1) this previews samples 0 and 1; if a
        # sample count was intended, `<` may be meant — confirm.
        if self.save_preview_samples and idx <= self.save_preview_samples:
            self.save_preview_image(dataset[idx], idx)

    pd.DataFrame(rows).to_csv(self.output_csv)