Attention

continuiti.networks.attention

Attention base class in continuiti.

Attention()

Bases: Module

Base class for various attention implementations.

Attention assigns varying importance to different parts of an input without relying on fixed kernels. The importance of each component is expressed by "soft" weights, which are computed by a specific algorithm (e.g. scaled dot-product attention).
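To make the idea of "soft" weights concrete, here is a minimal sketch of scaled dot-product attention in plain PyTorch. The function name and its standalone form are illustrative, not part of continuiti's API:

```python
import math

import torch


def scaled_dot_product(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor = None,
) -> torch.Tensor:
    """Illustrative scaled dot-product attention (not continuiti's implementation)."""
    # Similarity of each query to every key, scaled by sqrt(hidden_dim);
    # shape (batch_size, target_seq_length, source_seq_length).
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attn_mask is not None:
        # Positions where the boolean mask is False are excluded from the output.
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    # Softmax turns scores into "soft" weights: each row is non-negative and sums to 1.
    weights = scores.softmax(dim=-1)
    return weights @ value
```

The output has shape `(batch_size, target_seq_length, hidden_dim)`, matching the contract of `forward` below.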

Source code in src/continuiti/networks/attention.py
def __init__(self):
    super().__init__()

forward(query, key, value, attn_mask=None) abstractmethod

Computes the attention output for the given query, key, and value tensors.

PARAMETER DESCRIPTION
query

query tensor; shape (batch_size, target_seq_length, hidden_dim)

TYPE: Tensor

key

key tensor; shape (batch_size, source_seq_length, hidden_dim)

TYPE: Tensor

value

value tensor; shape (batch_size, source_seq_length, hidden_dim)

TYPE: Tensor

attn_mask

tensor indicating which values are used to calculate the output; shape (batch_size, target_seq_length, source_seq_length)

TYPE: Tensor DEFAULT: None

RETURNS DESCRIPTION
Tensor

tensor containing the outputs of the attention implementation; shape (batch_size, target_seq_length, hidden_dim)

Source code in src/continuiti/networks/attention.py
@abstractmethod
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Calculates the attention scores.

    Args:
        query: query tensor; shape (batch_size, target_seq_length, hidden_dim)
        key: key tensor; shape (batch_size, source_seq_length, hidden_dim)
        value: value tensor; shape (batch_size, source_seq_length, hidden_dim)
        attn_mask: tensor indicating which values are used to calculate the output;
            shape (batch_size, target_seq_length, source_seq_length)

    Returns:
        tensor containing the outputs of the attention implementation;
            shape (batch_size, target_seq_length, hidden_dim)
    """

Last update: 2024-08-22
Created: 2024-08-22