Attention
continuiti.networks.attention
Attention base class in continuiti.
Attention()
Bases: Module
Base class for various attention implementations.
Attention assigns varying importance to different parts of an input without relying on fixed kernels. The importance of each component is expressed through "soft" weights, which are computed by a specific algorithm (e.g., scaled dot-product attention).
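As a concrete example of such an algorithm, scaled dot-product attention normalizes the query-key similarities with a softmax to obtain the soft weights:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V,
$$

where $d$ is the hidden dimension of the key vectors.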
forward(query, key, value, attn_mask=None)
abstractmethod
Calculates the output of the attention mechanism.
| PARAMETER | DESCRIPTION |
| --- | --- |
| `query` | Query tensor; shape `(batch_size, target_seq_length, hidden_dim)`. TYPE: `Tensor` |
| `key` | Key tensor; shape `(batch_size, source_seq_length, hidden_dim)`. TYPE: `Tensor` |
| `value` | Value tensor; shape `(batch_size, source_seq_length, hidden_dim)`. TYPE: `Tensor` |
| `attn_mask` | Tensor indicating which values are used to calculate the output; shape `(batch_size, target_seq_length, source_seq_length)`. TYPE: `Tensor`, DEFAULT: `None` |
| RETURNS | DESCRIPTION |
| --- | --- |
| `Tensor` | Tensor containing the outputs of the attention implementation; shape `(batch_size, target_seq_length, hidden_dim)` |
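A minimal sketch of a concrete subclass, assuming the import path shown on this page and that `attn_mask` is a boolean tensor where `True` marks positions that contribute to the output; the class name and the usage snippet are illustrative, not part of continuiti:

```python
import torch
from continuiti.networks.attention import Attention  # module path from this page


class SimpleScaledDotProduct(Attention):
    # Illustrative subclass implementing scaled dot-product attention.

    def forward(self, query, key, value, attn_mask=None):
        # Soft weights from query-key similarities, scaled by sqrt(hidden_dim).
        scores = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
        if attn_mask is not None:
            # Assumption: True marks positions that are used; others are masked out.
            scores = scores.masked_fill(~attn_mask.bool(), float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        # Weighted sum over source positions -> (batch_size, target_seq_length, hidden_dim).
        return weights @ value


# Shapes follow the convention documented above.
q = torch.rand(2, 5, 16)   # (batch_size, target_seq_length, hidden_dim)
k = torch.rand(2, 7, 16)   # (batch_size, source_seq_length, hidden_dim)
v = torch.rand(2, 7, 16)   # (batch_size, source_seq_length, hidden_dim)
out = SimpleScaledDotProduct()(q, k, v)
assert out.shape == (2, 5, 16)
```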
Source code in src/continuiti/networks/attention.py
Created: 2024-08-22