PyTorch TensorBoard add_graph() error

Error

RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module’s inputs. Consider using a constant container instead (e.g. for list, use a tuple instead. for dict, use a NamedTuple instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.

Possible cause

The error is raised when the network's forward() returns a list or a dict.
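The cause can be reproduced without TensorBoard. The following is a minimal sketch, assuming that add_graph() traces the model with torch.jit.trace in its default strict mode. DictOutputNet and trace_error_message are hypothetical names used only for illustration.

```python
import torch


class DictOutputNet(torch.nn.Module):
    """Hypothetical toy model whose forward() returns a dict."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, input_x: torch.Tensor) -> dict:
        return {"out": self.linear(input_x)}


def trace_error_message(model: torch.nn.Module, example: torch.Tensor) -> str:
    """Trace in strict mode (the default); return the error text, or "" on success."""
    try:
        torch.jit.trace(model, example)
    except RuntimeError as err:
        return str(err)
    return ""


message = trace_error_message(DictOutputNet(), torch.randn(1, 4))
print(message[:80])
```

If the printed text starts like the message quoted above, the dict output is what the tracer is rejecting.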

Solution

Wrap the model so that a list output is converted to a tuple and a dict output is converted to a NamedTuple.

from collections import namedtuple
from typing import Any

import torch


# pylint: disable = abstract-method
class ModelWrapper(torch.nn.Module):
    """
    Wrapper class for a model whose forward() returns a dict or a list.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Init call.
        """
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor) -> Any:
        """
        Wrap forward call.
        """
        data = self.model(input_x)

        if isinstance(data, dict):
            data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys()))  # type: ignore
            data = data_named_tuple(**data)  # type: ignore

        elif isinstance(data, list):
            data = tuple(data)

        return data

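
# your_model, input_image, and writer (a TensorBoard SummaryWriter)
# are assumed to be defined elsewhere.
model_wrapper = ModelWrapper(your_model)
writer.add_graph(model_wrapper, input_image)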
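
To see what the conversion inside forward() produces, here is the same logic as a standalone sketch that does not need torch. to_constant_container is a hypothetical helper name, and the sample values are made up for illustration.

```python
from collections import namedtuple
from typing import Any


def to_constant_container(data: Any) -> Any:
    """Hypothetical helper mirroring the conversion in ModelWrapper.forward."""
    if isinstance(data, dict):
        # Sorting the keys gives a deterministic field order.
        fields = sorted(data.keys())
        return namedtuple("ModelEndpoints", fields)(**data)
    if isinstance(data, list):
        return tuple(data)
    # Anything else (e.g. a single tensor) passes through unchanged.
    return data


endpoints = to_constant_container({"out": 1.0, "aux": 2.0})
print(endpoints._fields)  # ('aux', 'out')
print(endpoints.out)      # 1.0

pair = to_constant_container([1, 2])
print(pair)               # (1, 2)
```

Note that namedtuple field names must be valid Python identifiers, so this approach only works when every dict key is a string such as "out" or "aux"; a key like "layer-1" would raise a ValueError.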

Reference: https://discuss.pytorch.org/t/a-tensorboard-problem-about-use-add-graph-method-for-deeplab-v3-in-torchvision/95808
