
Deep residual network

continuiti.networks.deep_residual_network

Deep residual network in continuiti.

ResidualLayer(width, act=None, device=None)

Bases: Module

Residual layer: a linear transform, activation, and layer normalization, combined with a skip connection.

PARAMETER DESCRIPTION
width

Width of the layer.

TYPE: int

act

Activation function.

TYPE: Optional[Module] DEFAULT: None

device

Device.

TYPE: Optional[device] DEFAULT: None

Source code in src/continuiti/networks/deep_residual_network.py
def __init__(
    self,
    width: int,
    act: Optional[torch.nn.Module] = None,
    device: Optional[torch.device] = None,
):
    super().__init__()
    self.layer = torch.nn.Linear(width, width, device=device)
    self.act = act or torch.nn.GELU()
    self.norm = torch.nn.LayerNorm(width, device=device)

forward(x)

Forward pass.

Source code in src/continuiti/networks/deep_residual_network.py
def forward(self, x: torch.Tensor):
    """Forward pass."""
    return self.norm(self.act(self.layer(x))) + x
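To illustrate, a minimal usage sketch. The class definition is reproduced from the source above so the snippet is self-contained; the width of 32 and batch size of 8 are arbitrary choices for the example:

```python
from typing import Optional

import torch


class ResidualLayer(torch.nn.Module):
    """Residual layer: LayerNorm(act(Linear(x))) + x (as in the source above)."""

    def __init__(
        self,
        width: int,
        act: Optional[torch.nn.Module] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.layer = torch.nn.Linear(width, width, device=device)
        self.act = act or torch.nn.GELU()
        self.norm = torch.nn.LayerNorm(width, device=device)

    def forward(self, x: torch.Tensor):
        return self.norm(self.act(self.layer(x))) + x


layer = ResidualLayer(width=32)
x = torch.randn(8, 32)
y = layer(x)

# The skip connection requires input and output widths to match,
# which is why the layer maps width -> width.
assert y.shape == x.shape
```

Because the residual branch is added to the identity, the layer preserves the tensor shape; stacking many such layers keeps gradients well-behaved in deep networks.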

DeepResidualNetwork(input_size, output_size, width, depth, act=None, device=None)

Bases: Module

Deep residual network: a linear input layer, followed by residual layers, followed by a linear output layer.

PARAMETER DESCRIPTION
input_size

Size of input tensor

TYPE: int

output_size

Size of output tensor

TYPE: int

width

Width of hidden layers

TYPE: int

depth

Depth of the network (a depth of d creates d - 1 residual layers between the input and output layers)

TYPE: int

act

Activation function

TYPE: Optional[Module] DEFAULT: None

device

Device.

TYPE: Optional[device] DEFAULT: None

Source code in src/continuiti/networks/deep_residual_network.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    width: int,
    depth: int,
    act: Optional[torch.nn.Module] = None,
    device: Optional[torch.device] = None,
):
    assert depth >= 1, "DeepResidualNetwork has at least depth 1."
    super().__init__()

    self.act = act or torch.nn.GELU()
    self.first_layer = torch.nn.Linear(input_size, width, device=device)
    self.hidden_layers = torch.nn.ModuleList(
        [
            ResidualLayer(
                width,
                act=self.act,
                device=device,
            )
            for _ in range(1, depth)
        ]
    )
    self.last_layer = torch.nn.Linear(width, output_size, device=device)

forward(x)

Forward pass.

Source code in src/continuiti/networks/deep_residual_network.py
def forward(self, x):
    """Forward pass."""
    x = self.first_layer(x)
    x = self.act(x)
    for layer in self.hidden_layers:
        x = layer(x)
    return self.last_layer(x)
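Putting both classes together, a self-contained usage sketch (definitions condensed from the source above; the sizes 2, 1, 32, and the depth of 4 are arbitrary example values):

```python
from typing import Optional

import torch


class ResidualLayer(torch.nn.Module):
    """Residual layer from the source above."""

    def __init__(self, width: int, act: Optional[torch.nn.Module] = None,
                 device: Optional[torch.device] = None):
        super().__init__()
        self.layer = torch.nn.Linear(width, width, device=device)
        self.act = act or torch.nn.GELU()
        self.norm = torch.nn.LayerNorm(width, device=device)

    def forward(self, x: torch.Tensor):
        return self.norm(self.act(self.layer(x))) + x


class DeepResidualNetwork(torch.nn.Module):
    """Deep residual network from the source above."""

    def __init__(self, input_size: int, output_size: int, width: int,
                 depth: int, act: Optional[torch.nn.Module] = None,
                 device: Optional[torch.device] = None):
        assert depth >= 1, "DeepResidualNetwork has at least depth 1."
        super().__init__()
        self.act = act or torch.nn.GELU()
        self.first_layer = torch.nn.Linear(input_size, width, device=device)
        # depth - 1 residual layers: range(1, depth)
        self.hidden_layers = torch.nn.ModuleList(
            [ResidualLayer(width, act=self.act, device=device)
             for _ in range(1, depth)]
        )
        self.last_layer = torch.nn.Linear(width, output_size, device=device)

    def forward(self, x):
        x = self.act(self.first_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        return self.last_layer(x)


# Map 2D inputs to scalar outputs through 32-wide hidden layers.
net = DeepResidualNetwork(input_size=2, output_size=1, width=32, depth=4)
x = torch.randn(16, 2)
y = net(x)

assert y.shape == (16, 1)
# depth=4 yields 3 residual layers between the input and output layers.
assert len(net.hidden_layers) == 3
```

Note that with depth=1 the module list is empty and the network reduces to two linear layers with one activation in between; the residual layers only appear for depth >= 2.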

Last update: 2024-08-20
Created: 2024-08-20