diff --git a/README.md b/README.md
index 8c1d5b920..a418d09f0 100644
--- a/README.md
+++ b/README.md
@@ -147,6 +147,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
| 交通预测 | [TGCN 交通流量预测](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/tgcn) | 数据驱动 | GCN & CNN | 监督学习 | [PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip) | - |
| 遥感图像分割 | [UNetFormer 遥感图像分割](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/unetformer) | 数据驱动 | UNetFormer | 监督学习 | [Vaihingen](https://paperswithcode.com/dataset/isprs-vaihingen) | [Paper](https://github.com/WangLibo1995/GeoSeg) |
| 生成模型| [图像生成中的梯度惩罚应用](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/wgan_gp)|数据驱动|WGAN GP|监督学习|[Data1](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)
[Data2](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)| [Paper](https://github.com/igul222/improved_wgan_training) |
+| 遥感图像分割 | [UTAE 遥感时序语义/全景分割](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/UTAE/) | 数据驱动 | UTAE | 监督学习 | [PASTIS](https://zenodo.org/records/5012942) | [Paper](https://arxiv.org/abs/2107.07933) |
diff --git a/docs/index.md b/docs/index.md
index 0fb15089b..d69e8ec80 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -164,6 +164,7 @@
| 遥感图像分割 | [UNetFormer分割图像](./zh/examples/unetformer.md) | 数据驱动 | UNetformer | 监督学习 | [Vaihingen](https://paperswithcode.com/dataset/isprs-vaihingen) | [Paper](https://github.com/WangLibo1995/GeoSeg) |
| 交通预测 | [TGCN 交通流量预测](./zh/examples/tgcn.md) | 数据驱动 | GCN & CNN | 监督学习 | [PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip) | - |
| 生成模型| [图像生成中的梯度惩罚应用](./zh/examples/wgan_gp.md)|数据驱动|WGAN GP|监督学习|[Data1](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)
[Data2](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)| [Paper](https://github.com/igul222/improved_wgan_training) |
+ | 遥感图像分割 | [UTAE 遥感时序语义/全景分割](./zh/examples/UTAE.md) | 数据驱动 | UTAE | 监督学习 | [PASTIS](https://zenodo.org/records/5012942) | [Paper](https://arxiv.org/abs/2107.07933) |
=== "化学科学"
diff --git a/docs/zh/examples/UTAE.md b/docs/zh/examples/UTAE.md
new file mode 100644
index 000000000..2c40d0684
--- /dev/null
+++ b/docs/zh/examples/UTAE.md
@@ -0,0 +1,142 @@
+# 农作物种植情况实时监测
+
+!!! note
+
+ 运行模型前请在 [PASTIS官网](https://zenodo.org/records/5012942) 中下载PASTIS数据集,并将其放在 `./UTAE/data/` 文件夹下。
+
+=== "模型训练命令"
+
+ ``` sh
+ # 语义分割任务
+ python train_semantic.py \
+ --dataset_folder "./data/PASTIS" \
+ --epochs 100 \
+ --batch_size 2 \
+ --num_workers 0 \
+ --display_step 10
+ # 全景分割任务
+ python train_panoptic.py \
+ --dataset_folder "./data/PASTIS" \
+ --epochs 100 \
+ --batch_size 2 \
+ --num_workers 0 \
+ --warmup 5 \
+ --l_shape 1 \
+ --display_step 10
+ ```
+
+=== "模型评估命令"
+
+ ``` sh
+ # 语义分割任务
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/utae/semantic.pdparams -P ./pretrained/
+ python test_semantic.py \
+ --weight_file ./pretrained/semantic.pdparams \
+ --dataset_folder "./data/PASTIS" \
+ --device gpu
+ --num_workers 0
+ # 全景分割任务
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/utae/panoptic.pdparams -P ./pretrained/
+ python test_panoptic.py \
+ --weight_folder ./pretrained/panoptic.pdparams \
+ --dataset_folder ./data/PASTIS \
+ --batch_size 2 \
+ --num_workers 0 \
+ --device gpu
+ ```
+
+| 预训练模型 | 指标 |
+|:--| :--|
+| [语义分割任务](https://paddle-org.bj.bcebos.com/paddlescience/models/utae/semantic.pdparams) | OA (Over all Accuracy): 86.7%
mIoU (mean Intersection over Union): 72.6% |
+| [全景分割任务](https://paddle-org.bj.bcebos.com/paddlescience/models/utae/panoptic.pdparams) | SQ (Segmentation Quality): 83.8
RQ (Recognition Quality): 58.9
PQ (Panoptic Quality): 49.7 |
+
+## 背景简介
+对农作物种植分布和生长状态进行高效、精准的监测,是现代智慧农业和粮食安全领域的核心需求。传统的人工勘察方法耗时费力,而利用单时相卫星影像进行分析的方法,难以应对云层遮挡问题,也无法捕捉作物在整个生长周期中的动态变化规律。
+
+卫星图像时间序列(Satellite Image Time Series, SITS)技术为解决这一难题提供了新的途径。通过持续采集同一区域在不同时间的多光谱影像,SITS数据蕴含了作物从播种、出苗、生长、成熟到收割的全过程光谱和纹理信息。然而,SITS数据具有时序长、维度高、时空关联性强等特点,如何从中高效地提取特征并进行精确的像素级分类(语义分割)是一项重大的技术挑战。
+
+本项目基于模型U-TAE(U-Net Temporal Attention Encoder),利用PaddlePaddle深度学习框架进行实现,旨在构建一个端到端的解决方案,对PASTIS数据集中的卫星影像时间序列进行语义分割,从而实现对多种农作物种植情况的自动化、高精度识别与监测。该技术可广泛应用于农业资源调查、产量预估、灾害评估等领域,具有重要的实用价值。
+
+## 模型原理
+
+本章节仅对U-TAE的模型原理进行简单介绍,详细的理论推导请参考论文:[Panoptic Segmentation of Satellite Image Time Series with Convolutional Temporal Attention Networks](https://arxiv.org/abs/2107.07933)
+
+### 1. 整体结构
+
+UTAE(U-Net Temporal Attention Encoder)采用编码器-解码器架构,专为卫星图像时间序列语义分割设计:
+
+- **编码器**:使用轻量化的ResNet-18,提取单时相的空间特征。
+- **解码器**:集成U-TAE模块,利用时间注意力机制聚合多时相的全局上下文信息。
+- **输出**:生成与输入相同分辨率的像素级类别概率图。
+
+
+
+### 2. 时间注意力机制(Temporal Attention)
+
+对于长度为 $T$ 的帧序列,UTAE在解码阶段为每一帧计算帧间相似度权重,实现自适应的时序信息聚合:
+
+- **Query**:当前帧的特征 $\mathbf{Q}$
+- **Key / Value**:全部帧的特征 $\mathbf{K}, \mathbf{V}$
+
+计算步骤如下:
+
+$$
+\text{权重} = \text{Softmax}(\mathbf{Q} \cdot \mathbf{K}^\top)
+$$
+
+然后,利用这些权重对全部帧的特征进行加权求和得到聚合特征:
+
+$$
+\mathbf{F}_{\text{agg}} = \sum_{t=1}^{T} \alpha_t \mathbf{V}_t, \quad \text{其中} \quad \alpha_t = \text{Softmax}(\mathbf{Q} \cdot \mathbf{K}_t^\top)
+$$
+
+该机制能自动抑制云层、阴影等低质量帧,提升作物边界的清晰度。
+
+### 3. 全局-局部注意力块(GLTB)
+
+每个解码器层包含两个并行分支:
+
+- **全局分支**:采用多头自注意力(Multi-Head Self-Attention)机制,建模田块级的长程依赖关系。
+
+- **局部分支**:使用 $3 \times 3$ 深度可分离卷积,注重边缘和细节信息的保留。
+
+两个分支的输出通过逐元素相加融合,既保持全局上下文,又保留局部纹理细节。
+
+### 4. 实时推理优化
+
+为实现高效实时推理,模型采用以下优化策略:
+
+- **轻量级骨干**:ResNet-18,参数量小于12M。
+- **帧间共享权重**:在同一序列中,Key和Value只计算一次,避免重复计算。
+- **滑动窗口推理**:将大图划分为多个块进行逐块推理,确保显存占用恒定。
+
+## 数据集介绍
+
+PASTIS数据集,该数据集由2433个 $10\times128\times128$ 形状的多光谱图像序列组成。每个序列包含2018年9月至2019年11月之间的38至61个观察点,总计超过20亿像素。获取间隔时间不均匀,平均为5天。这种缺乏规律性的现象是由于卫星数据提供商对大量云层覆盖的采集进行了自动过滤。该数据集覆盖4000多平方公里,图像来自法国四个不同地区,气候和作物分布多样。
+数据集可通过 [PASTIS官网](https://zenodo.org/records/5012942) 下载。
+
+## 模型实现
+
+### 模型构建
+
+本案例基于 UTAE(U-TAE) 实现,用 PaddleScience 封装如下:
+
+``` py linenums="12" title="examples/UTAE/src/backbones/utae.py"
+--8<--
+examples/UTAE/src/backbones/utae.py:12:177
+--8<--
+```
+
+## 可视化结果
+在 PASTIS 数据集上,本案例复现了全景分割预测与语义分割预测的可视化结果如图所示:
+
+
+
+(a)原始图像 (b)标注(真实标签)(c) 全景分割预测 (d) 语义分割预测
+
+上图展示了 PASTIS 数据集上的农田地块分割结果。在图中用不同颜色表示不同的地块。绿色圈出的位置代表大块地被错误识别为单一地块;红色圈出的位置代表很多细长地块未被正确检测;蓝色圈出的位置展示了 全景分割优于语义分割的情况。模型在区域边界检测方面具有较好表现,尤其在复杂边界的恢复上有所优势。但在面对细长、破碎或复杂地块时,仍然存在挑战,容易导致置信度下降或检测失败。
+## 参考文献
+
+- U-TAE 原论文:[Panoptic Segmentation of Satellite Image Time Series with Convolutional Temporal Attention Networks](https://arxiv.org/abs/2107.07933)
+- 源代码实现:[https://github.com/VSainteuf/utae-paps](https://github.com/VSainteuf/utae-paps)
+- 数据集与基准:[https://github.com/VSainteuf/pastis-benchmark](https://github.com/VSainteuf/pastis-benchmark)
diff --git a/examples/UTAE/src/backbones/convlstm.py b/examples/UTAE/src/backbones/convlstm.py
new file mode 100644
index 000000000..aacf3ba6e
--- /dev/null
+++ b/examples/UTAE/src/backbones/convlstm.py
@@ -0,0 +1,262 @@
+"""
+ConvLSTM Implementation (Paddle Version)
+Converted to PaddlePaddle
+"""
+import paddle
+import paddle.nn as nn
+
+
+class ConvLSTMCell(nn.Layer):
+ """
+ Initialize ConvLSTM cell.
+
+ Parameters
+ ----------
+ input_size: (int, int)
+ Height and width of input tensor as (height, width).
+ input_dim: int
+ Number of channels of input tensor.
+ hidden_dim: int
+ Number of channels of hidden state.
+ kernel_size: (int, int)
+ Size of the convolutional kernel.
+ bias: bool
+ Whether or not to add the bias.
+ """
+
+ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
+
+ super(ConvLSTMCell, self).__init__()
+ self.height, self.width = input_size
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+
+ self.kernel_size = kernel_size
+ self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+ self.bias = bias
+
+ self.conv = nn.Conv2D(
+ in_channels=self.input_dim + self.hidden_dim,
+ out_channels=4 * self.hidden_dim,
+ kernel_size=self.kernel_size,
+ padding=self.padding,
+ bias_attr=self.bias,
+ )
+
+ def forward(self, input_tensor, cur_state):
+ h_cur, c_cur = cur_state
+
+ combined = paddle.concat(
+ [input_tensor, h_cur], axis=1
+ ) # concatenate along channel axis
+
+ combined_conv = self.conv(combined)
+ cc_i, cc_f, cc_o, cc_g = paddle.split(combined_conv, self.hidden_dim, axis=1)
+ i = paddle.nn.functional.sigmoid(cc_i)
+ f = paddle.nn.functional.sigmoid(cc_f)
+ o = paddle.nn.functional.sigmoid(cc_o)
+ g = paddle.nn.functional.tanh(cc_g)
+
+ c_next = f * c_cur + i * g
+ h_next = o * paddle.nn.functional.tanh(c_next)
+
+ return h_next, c_next
+
+ def init_hidden(self, batch_size):
+ return (
+ paddle.zeros([batch_size, self.hidden_dim, self.height, self.width]),
+ paddle.zeros([batch_size, self.hidden_dim, self.height, self.width]),
+ )
+
+
+class ConvLSTM(nn.Layer):
+ def __init__(
+ self,
+ input_size,
+ input_dim,
+ hidden_dim,
+ kernel_size,
+ num_layers=1,
+ batch_first=True,
+ bias=True,
+ return_all_layers=False,
+ ):
+ super(ConvLSTM, self).__init__()
+
+ self._check_kernel_size_consistency(kernel_size)
+
+ # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
+ kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
+ hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
+ if not len(kernel_size) == len(hidden_dim) == num_layers:
+ raise ValueError("Inconsistent list length.")
+
+ self.height, self.width = input_size
+
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.kernel_size = kernel_size
+ self.num_layers = num_layers
+ self.batch_first = batch_first
+ self.bias = bias
+ self.return_all_layers = return_all_layers
+
+ cell_list = []
+ for i in range(0, self.num_layers):
+ cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
+
+ cell_list.append(
+ ConvLSTMCell(
+ input_size=(self.height, self.width),
+ input_dim=cur_input_dim,
+ hidden_dim=self.hidden_dim[i],
+ kernel_size=self.kernel_size[i],
+ bias=self.bias,
+ )
+ )
+
+ self.cell_list = nn.LayerList(cell_list)
+
+ def forward(self, input_tensor, hidden_state=None, pad_mask=None):
+ """
+ Parameters
+ ----------
+ input_tensor: todo
+ 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
+ hidden_state: todo
+ None. todo implement stateful
+ pad_mask (b, t)
+ Returns
+ -------
+ last_state_list, layer_output
+ """
+ if not self.batch_first:
+ # (t, b, c, h, w) -> (b, t, c, h, w)
+ input_tensor = input_tensor.transpose([1, 0, 2, 3, 4])
+
+ # Implement stateful ConvLSTM
+ if hidden_state is not None:
+ raise NotImplementedError()
+ else:
+ hidden_state = self._init_hidden(batch_size=input_tensor.shape[0])
+
+ layer_output_list = []
+ last_state_list = []
+
+ seq_len = input_tensor.shape[1]
+ cur_layer_input = input_tensor
+
+ for layer_idx in range(self.num_layers):
+ h, c = hidden_state[layer_idx]
+ output_inner = []
+ for t in range(seq_len):
+ if pad_mask is not None:
+ # Only process non-padded timesteps
+ mask = ~pad_mask[:, t] # B
+ if mask.any():
+ h_t, c_t = self.cell_list[layer_idx](
+ cur_layer_input[mask, t, :, :, :],
+ cur_state=(h[mask], c[mask]),
+ )
+ h[mask] = h_t
+ c[mask] = c_t
+ else:
+ h, c = self.cell_list[layer_idx](
+ input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
+ )
+ output_inner.append(h)
+
+ layer_output = paddle.stack(output_inner, axis=1)
+ cur_layer_input = layer_output
+
+ layer_output_list.append(layer_output)
+ last_state_list.append([h, c])
+
+ if not self.return_all_layers:
+ layer_output_list = layer_output_list[-1:]
+ last_state_list = last_state_list[-1:]
+
+ return layer_output_list, last_state_list
+
+ def _init_hidden(self, batch_size):
+ init_states = []
+ for i in range(self.num_layers):
+ init_states.append(self.cell_list[i].init_hidden(batch_size))
+ return init_states
+
+ @staticmethod
+ def _check_kernel_size_consistency(kernel_size):
+ if not (
+ isinstance(kernel_size, tuple)
+ or (
+ isinstance(kernel_size, list)
+ and all([isinstance(elem, tuple) for elem in kernel_size])
+ )
+ ):
+ raise ValueError("`kernel_size` must be tuple or list of tuples")
+
+ @staticmethod
+ def _extend_for_multilayer(param, num_layers):
+ if not isinstance(param, list):
+ param = [param] * num_layers
+ return param
+
+
+class BConvLSTM(nn.Layer):
+ """Bidirectional ConvLSTM"""
+
+ def __init__(
+ self,
+ input_size,
+ input_dim,
+ hidden_dim,
+ kernel_size,
+ num_layers=1,
+ batch_first=True,
+ bias=True,
+ ):
+ super(BConvLSTM, self).__init__()
+
+ self.forward_net = ConvLSTM(
+ input_size,
+ input_dim,
+ hidden_dim,
+ kernel_size,
+ num_layers,
+ batch_first,
+ bias,
+ return_all_layers=False,
+ )
+ self.backward_net = ConvLSTM(
+ input_size,
+ input_dim,
+ hidden_dim,
+ kernel_size,
+ num_layers,
+ batch_first,
+ bias,
+ return_all_layers=False,
+ )
+
+ def forward(self, input_tensor, pad_mask=None):
+ # Forward pass
+ forward_output, _ = self.forward_net(input_tensor, pad_mask=pad_mask)
+
+ # Backward pass - reverse the sequence
+ reversed_input = paddle.flip(input_tensor, [1]) # Reverse time dimension
+ if pad_mask is not None:
+ reversed_pad_mask = paddle.flip(pad_mask, [1])
+ else:
+ reversed_pad_mask = None
+
+ backward_output, _ = self.backward_net(
+ reversed_input, pad_mask=reversed_pad_mask
+ )
+ backward_output = [
+ paddle.flip(output, [1]) for output in backward_output
+ ] # Reverse back
+
+ # Concatenate forward and backward outputs
+ combined_output = paddle.concat([forward_output[0], backward_output[0]], axis=2)
+
+ return combined_output
diff --git a/examples/UTAE/src/backbones/ltae.py b/examples/UTAE/src/backbones/ltae.py
new file mode 100644
index 000000000..6c0417818
--- /dev/null
+++ b/examples/UTAE/src/backbones/ltae.py
@@ -0,0 +1,222 @@
+import copy
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from src.backbones.positional_encoding import PositionalEncoder
+
+
+class LTAE2d(nn.Layer):
+ """
+ Lightweight Temporal Attention Encoder (L-TAE) for image time series.
+ Attention-based sequence encoding that maps a sequence of images to a single feature map.
+ A shared L-TAE is applied to all pixel positions of the image sequence.
+ Args:
+ in_channels (int): Number of channels of the input embeddings.
+ n_head (int): Number of attention heads.
+ d_k (int): Dimension of the key and query vectors.
+ mlp (List[int]): Widths of the layers of the MLP that processes the concatenated outputs of the attention heads.
+ dropout (float): dropout
+ d_model (int, optional): If specified, the input tensors will first processed by a fully connected layer
+ to project them into a feature space of dimension d_model.
+ T (int): Period to use for the positional encoding.
+ return_att (bool): If true, the module returns the attention masks along with the embeddings (default False)
+ positional_encoding (bool): If False, no positional encoding is used (default True).
+ """
+
+ def __init__(
+ self,
+ in_channels=128,
+ n_head=16,
+ d_k=4,
+ mlp=[256, 128],
+ dropout=0.2,
+ d_model=256,
+ T=1000,
+ return_att=False,
+ positional_encoding=True,
+ ):
+
+ super(LTAE2d, self).__init__()
+ self.in_channels = in_channels
+ self.mlp = copy.deepcopy(mlp)
+ self.return_att = return_att
+ self.n_head = n_head
+
+ if d_model is not None:
+ self.d_model = d_model
+ self.inconv = nn.Conv1D(in_channels, d_model, 1)
+ else:
+ self.d_model = in_channels
+ self.inconv = None
+ assert self.mlp[0] == self.d_model
+
+ if positional_encoding:
+ self.positional_encoder = PositionalEncoder(
+ self.d_model // n_head, T=T, repeat=n_head
+ )
+ else:
+ self.positional_encoder = None
+
+ self.attention_heads = MultiHeadAttention(
+ n_head=n_head, d_k=d_k, d_in=self.d_model
+ )
+ self.in_norm = nn.GroupNorm(
+ num_groups=n_head,
+ num_channels=self.in_channels,
+ )
+ self.out_norm = nn.GroupNorm(
+ num_groups=n_head,
+ num_channels=mlp[-1],
+ )
+
+ layers = []
+ for i in range(len(self.mlp) - 1):
+ layers.extend(
+ [
+ nn.Linear(self.mlp[i], self.mlp[i + 1]),
+ nn.BatchNorm1D(self.mlp[i + 1]),
+ nn.ReLU(),
+ ]
+ )
+
+ self.mlp = nn.Sequential(*layers)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False):
+ sz_b, seq_len, d, h, w = x.shape
+ if pad_mask is not None:
+ pad_mask = (
+ pad_mask.unsqueeze(-1).tile([1, 1, h]).unsqueeze(-1).tile([1, 1, 1, w])
+ ) # BxTxHxW
+ pad_mask = pad_mask.transpose([0, 2, 3, 1]).reshape([sz_b * h * w, seq_len])
+
+ out = x.transpose([0, 3, 4, 1, 2]).reshape([sz_b * h * w, seq_len, d])
+ out = self.in_norm(out.transpose([0, 2, 1])).transpose([0, 2, 1])
+
+ if self.inconv is not None:
+ out = self.inconv(out.transpose([0, 2, 1])).transpose([0, 2, 1])
+
+ if self.positional_encoder is not None:
+ bp = (
+ batch_positions.unsqueeze(-1)
+ .tile([1, 1, h])
+ .unsqueeze(-1)
+ .tile([1, 1, 1, w])
+ ) # BxTxHxW
+ bp = bp.transpose([0, 2, 3, 1]).reshape([sz_b * h * w, seq_len])
+ out = out + self.positional_encoder(bp)
+
+ out, attn = self.attention_heads(out, pad_mask=pad_mask)
+
+ out = out.transpose([1, 0, 2]).reshape([sz_b * h * w, -1]) # Concatenate heads
+ out = self.dropout(self.mlp(out))
+ out = self.out_norm(out) if self.out_norm is not None else out
+ out = out.reshape([sz_b, h, w, -1]).transpose([0, 3, 1, 2])
+
+ attn = attn.reshape([self.n_head, sz_b, h, w, seq_len]).transpose(
+ [0, 1, 4, 2, 3]
+ ) # head x b x t x h x w
+
+ if self.return_att:
+ return out, attn
+ else:
+ return out
+
+
+class MultiHeadAttention(nn.Layer):
+ """Multi-Head Attention module
+ Modified from github.com/jadore801120/attention-is-all-you-need-pytorch
+ """
+
+ def __init__(self, n_head, d_k, d_in):
+ super().__init__()
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_in = d_in
+
+ self.Q = self.create_parameter(
+ shape=[n_head, d_k],
+ dtype="float32",
+ default_initializer=nn.initializer.Normal(mean=0.0, std=np.sqrt(2.0 / d_k)),
+ )
+
+ self.fc1_k = nn.Linear(d_in, n_head * d_k)
+ # Initialize weights
+ nn.initializer.Normal(mean=0.0, std=np.sqrt(2.0 / d_k))(self.fc1_k.weight)
+
+ self.attention = ScaledDotProductAttention(
+ temperature=float(np.power(d_k, 0.5))
+ )
+
+ def forward(self, v, pad_mask=None, return_comp=False):
+ d_k, d_in, n_head = self.d_k, self.d_in, self.n_head
+ sz_b, seq_len, _ = v.shape
+
+ q = paddle.stack([self.Q for _ in range(sz_b)], axis=1).reshape(
+ [-1, d_k]
+ ) # (n*b) x d_k
+
+ k = self.fc1_k(v).reshape([sz_b, seq_len, n_head, d_k])
+ k = k.transpose([2, 0, 1, 3]).reshape([-1, seq_len, d_k]) # (n*b) x lk x dk
+
+ if pad_mask is not None:
+ pad_mask = pad_mask.tile(
+ [n_head, 1]
+ ) # replicate pad_mask for each head (nxb) x lk
+
+ # Split v into n_head chunks
+ chunk_size = v.shape[-1] // n_head
+ v_chunks = []
+ for i in range(n_head):
+ start_idx = i * chunk_size
+ end_idx = (i + 1) * chunk_size
+ v_chunks.append(v[:, :, start_idx:end_idx])
+ v = paddle.stack(v_chunks).reshape([n_head * sz_b, seq_len, -1])
+
+ if return_comp:
+ output, attn, comp = self.attention(
+ q, k, v, pad_mask=pad_mask, return_comp=return_comp
+ )
+ else:
+ output, attn = self.attention(
+ q, k, v, pad_mask=pad_mask, return_comp=return_comp
+ )
+ attn = attn.reshape([n_head, sz_b, 1, seq_len])
+ attn = attn.squeeze(axis=2)
+
+ output = output.reshape([n_head, sz_b, 1, d_in // n_head])
+ output = output.squeeze(axis=2)
+
+ if return_comp:
+ return output, attn, comp
+ else:
+ return output, attn
+
+
+class ScaledDotProductAttention(nn.Layer):
+ """Scaled Dot-Product Attention
+ Modified from github.com/jadore801120/attention-is-all-you-need-pytorch
+ """
+
+ def __init__(self, temperature, attn_dropout=0.1):
+ super().__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(attn_dropout)
+ self.softmax = nn.Softmax(axis=2)
+
+ def forward(self, q, k, v, pad_mask=None, return_comp=False):
+ attn = paddle.matmul(q.unsqueeze(1), k.transpose([0, 2, 1]))
+ attn = attn / self.temperature
+ if pad_mask is not None:
+ attn = paddle.where(pad_mask.unsqueeze(1), paddle.to_tensor(-1e3), attn)
+ if return_comp:
+ comp = attn
+ attn = self.softmax(attn)
+ attn = self.dropout(attn)
+ output = paddle.matmul(attn, v)
+
+ if return_comp:
+ return output, attn, comp
+ else:
+ return output, attn
diff --git a/examples/UTAE/src/backbones/positional_encoding.py b/examples/UTAE/src/backbones/positional_encoding.py
new file mode 100644
index 000000000..2b944b56d
--- /dev/null
+++ b/examples/UTAE/src/backbones/positional_encoding.py
@@ -0,0 +1,39 @@
+import paddle
+import paddle.nn as nn
+
+
+class PositionalEncoder(nn.Layer):
+ def __init__(self, d, T=1000, repeat=None, offset=0):
+ super(PositionalEncoder, self).__init__()
+ self.d = d
+ self.T = T
+ self.repeat = repeat
+ self.denom = paddle.pow(
+ paddle.to_tensor(T, dtype="float32"),
+ 2 * (paddle.arange(offset, offset + d, dtype="float32") // 2) / d,
+ )
+ self.updated_location = False
+
+ def forward(self, batch_positions):
+ if not self.updated_location:
+ # Move to same device as input
+ if hasattr(batch_positions, "place"):
+ self.denom = (
+ self.denom.cuda()
+ if "gpu" in str(batch_positions.place)
+ else self.denom.cpu()
+ )
+ self.updated_location = True
+
+ sinusoid_table = (
+ batch_positions[:, :, None].astype("float32") / self.denom[None, None, :]
+ ) # B x T x C
+ sinusoid_table[:, :, 0::2] = paddle.sin(sinusoid_table[:, :, 0::2]) # dim 2i
+ sinusoid_table[:, :, 1::2] = paddle.cos(sinusoid_table[:, :, 1::2]) # dim 2i+1
+
+ if self.repeat is not None:
+ sinusoid_table = paddle.concat(
+ [sinusoid_table for _ in range(self.repeat)], axis=-1
+ )
+
+ return sinusoid_table
diff --git a/examples/UTAE/src/backbones/utae.py b/examples/UTAE/src/backbones/utae.py
new file mode 100644
index 000000000..32ce5ce99
--- /dev/null
+++ b/examples/UTAE/src/backbones/utae.py
@@ -0,0 +1,615 @@
+"""
+U-TAE Implementation (Paddle Version)
+Converted to PaddlePaddle
+"""
+import paddle
+import paddle.nn as nn
+from src.backbones.convlstm import BConvLSTM
+from src.backbones.convlstm import ConvLSTM
+from src.backbones.ltae import LTAE2d
+
+
+class UTAE(nn.Layer):
+ """
+ U-TAE architecture for spatio-temporal encoding of satellite image time series.
+ Args:
+ input_dim (int): Number of channels in the input images.
+ encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder.
+ This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1)
+ in the architecture.
+ The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution.
+ decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of
+ channels should be given is also from top to bottom. If this argument is not specified the decoder
+ will have the same configuration as the encoder.
+ out_conv (List[int]): Number of channels of the successive convolutions for the
+ str_conv_k (int): Kernel size of the strided up and down convolutions.
+ str_conv_s (int): Stride of the strided up and down convolutions.
+ str_conv_p (int): Padding of the strided up and down convolutions.
+ agg_mode (str): Aggregation mode for the skip connections. Can either be:
+ - att_group (default) : Attention weighted temporal average, using the same
+ channel grouping strategy as in the LTAE. The attention masks are bilinearly
+ resampled to the resolution of the skipped feature maps.
+ - att_mean : Attention weighted temporal average,
+ using the average attention scores across heads for each date.
+ - mean : Temporal average excluding padded dates.
+ encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be:
+ - group : GroupNorm (default)
+ - batch : BatchNorm
+ - instance : InstanceNorm
+ n_head (int): Number of heads in LTAE.
+ d_model (int): Parameter of LTAE
+ d_k (int): Key-Query space dimension
+ encoder (bool): If true, the feature maps instead of the class scores are returned (default False)
+ return_maps (bool): If true, the feature maps instead of the class scores are returned (default False)
+ pad_value (float): Value used by the dataloader for temporal padding.
+ padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2D).
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ encoder_widths=[64, 64, 64, 128],
+ decoder_widths=[32, 32, 64, 128],
+ out_conv=[32, 20],
+ str_conv_k=4,
+ str_conv_s=2,
+ str_conv_p=1,
+ agg_mode="att_group",
+ encoder_norm="group",
+ n_head=16,
+ d_model=256,
+ d_k=4,
+ encoder=False,
+ return_maps=False,
+ pad_value=0,
+ padding_mode="reflect",
+ ):
+
+ super(UTAE, self).__init__()
+ self.n_stages = len(encoder_widths)
+ self.return_maps = return_maps
+ self.encoder_widths = encoder_widths
+ self.decoder_widths = decoder_widths
+ self.enc_dim = (
+ decoder_widths[0] if decoder_widths is not None else encoder_widths[0]
+ )
+ self.stack_dim = (
+ sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths)
+ )
+ self.pad_value = pad_value
+ self.encoder = encoder
+ if encoder:
+ self.return_maps = True
+
+ if decoder_widths is not None:
+ assert len(encoder_widths) == len(decoder_widths)
+ assert encoder_widths[-1] == decoder_widths[-1]
+ else:
+ decoder_widths = encoder_widths
+
+ self.in_conv = ConvBlock(
+ nkernels=[input_dim] + [encoder_widths[0], encoder_widths[0]],
+ pad_value=pad_value,
+ norm=encoder_norm,
+ padding_mode=padding_mode,
+ )
+ self.down_blocks = nn.LayerList(
+ [
+ DownConvBlock(
+ d_in=encoder_widths[i],
+ d_out=encoder_widths[i + 1],
+ k=str_conv_k,
+ s=str_conv_s,
+ p=str_conv_p,
+ pad_value=pad_value,
+ norm=encoder_norm,
+ padding_mode=padding_mode,
+ )
+ for i in range(self.n_stages - 1)
+ ]
+ )
+ self.up_blocks = nn.LayerList(
+ [
+ UpConvBlock(
+ d_in=decoder_widths[i],
+ d_out=decoder_widths[i - 1],
+ d_skip=encoder_widths[i - 1],
+ k=str_conv_k,
+ s=str_conv_s,
+ p=str_conv_p,
+ norm="batch",
+ padding_mode=padding_mode,
+ )
+ for i in range(self.n_stages - 1, 0, -1)
+ ]
+ )
+ self.temporal_encoder = LTAE2d(
+ in_channels=encoder_widths[-1],
+ d_model=d_model,
+ n_head=n_head,
+ mlp=[d_model, encoder_widths[-1]],
+ return_att=True,
+ d_k=d_k,
+ )
+ self.temporal_aggregator = Temporal_Aggregator(mode=agg_mode)
+ self.out_conv = ConvBlock(
+ nkernels=[decoder_widths[0]] + out_conv, padding_mode=padding_mode
+ )
+
+ def forward(self, input, batch_positions=None, return_att=False):
+ # Create pad mask by comparing with pad_value
+ # Use safe tensor comparison to avoid type issues
+ pad_value_tensor = paddle.to_tensor(self.pad_value, dtype=input.dtype)
+ comparison = paddle.equal(input, pad_value_tensor)
+
+ # Sequentially reduce dimensions using all()
+ mask_step1 = paddle.all(comparison, axis=-1) # Reduce last dim
+ mask_step2 = paddle.all(mask_step1, axis=-1) # Reduce second-to-last dim
+ pad_mask = paddle.all(mask_step2, axis=-1) # Reduce third-to-last dim (BxT)
+ out = self.in_conv.smart_forward(input)
+ feature_maps = [out]
+ # SPATIAL ENCODER
+ for i in range(self.n_stages - 1):
+ out = self.down_blocks[i].smart_forward(feature_maps[-1])
+ feature_maps.append(out)
+ # TEMPORAL ENCODER
+ out, att = self.temporal_encoder(
+ feature_maps[-1], batch_positions=batch_positions, pad_mask=pad_mask
+ )
+ # SPATIAL DECODER
+ if self.return_maps:
+ maps = [out]
+ for i in range(self.n_stages - 1):
+ skip = self.temporal_aggregator(
+ feature_maps[-(i + 2)], pad_mask=pad_mask, attn_mask=att
+ )
+ out = self.up_blocks[i](out, skip)
+ if self.return_maps:
+ maps.append(out)
+
+ if self.encoder:
+ return out, maps
+ else:
+ out = self.out_conv(out)
+ if return_att:
+ return out, att
+ if self.return_maps:
+ return out, maps
+ else:
+ return out
+
+
+class TemporallySharedBlock(nn.Layer):
+ """
+ Helper module for convolutional encoding blocks that are shared across a sequence.
+ This module adds the self.smart_forward() method the the block.
+ smart_forward will combine the batch and temporal dimension of an input tensor
+ if it is 5-D and apply the shared convolutions to all the (batch x temp) positions.
+ """
+
+ def __init__(self, pad_value=None):
+ super(TemporallySharedBlock, self).__init__()
+ self.out_shape = None
+ self.pad_value = pad_value
+
+ def smart_forward(self, input):
+ if len(input.shape) == 4:
+ return self.forward(input)
+ else:
+ b, t, c, h, w = input.shape
+
+ if self.pad_value is not None:
+ dummy = paddle.zeros(input.shape).astype("float32")
+ self.out_shape = self.forward(dummy.reshape([b * t, c, h, w])).shape
+
+ out = input.reshape([b * t, c, h, w])
+ if self.pad_value is not None:
+ pad_value_tensor = paddle.to_tensor(self.pad_value, dtype=out.dtype)
+ comparison = paddle.equal(out, pad_value_tensor)
+ mask_step1 = paddle.all(comparison, axis=-1) # Reduce last dim
+ mask_step2 = paddle.all(
+ mask_step1, axis=-1
+ ) # Reduce second-to-last dim
+ pad_mask = paddle.all(mask_step2, axis=-1) # Reduce third-to-last dim
+ if pad_mask.any():
+ temp = paddle.ones(self.out_shape) * self.pad_value
+ temp[~pad_mask] = self.forward(out[~pad_mask])
+ out = temp
+ else:
+ out = self.forward(out)
+ else:
+ out = self.forward(out)
+ _, c, h, w = out.shape
+ out = out.reshape([b, t, c, h, w])
+ return out
+
+
+class ConvLayer(nn.Layer):
+ def __init__(
+ self,
+ nkernels,
+ norm="batch",
+ k=3,
+ s=1,
+ p=1,
+ n_groups=4,
+ last_relu=True,
+ padding_mode="reflect",
+ ):
+ super(ConvLayer, self).__init__()
+ layers = []
+ if norm == "batch":
+ nl = nn.BatchNorm2D
+ elif norm == "instance":
+ nl = nn.InstanceNorm2D
+ elif norm == "group":
+ nl = lambda num_feats: nn.GroupNorm(
+ num_channels=num_feats,
+ num_groups=n_groups,
+ )
+ else:
+ nl = None
+ for i in range(len(nkernels) - 1):
+ layers.append(
+ nn.Conv2D(
+ in_channels=nkernels[i],
+ out_channels=nkernels[i + 1],
+ kernel_size=k,
+ padding=p,
+ stride=s,
+ padding_mode=padding_mode,
+ )
+ )
+ if nl is not None:
+ layers.append(nl(nkernels[i + 1]))
+
+ if last_relu:
+ layers.append(nn.ReLU())
+ elif i < len(nkernels) - 2:
+ layers.append(nn.ReLU())
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, input):
+ return self.conv(input)
+
+
+class ConvBlock(TemporallySharedBlock):
+ def __init__(
+ self,
+ nkernels,
+ pad_value=None,
+ norm="batch",
+ last_relu=True,
+ padding_mode="reflect",
+ ):
+ super(ConvBlock, self).__init__(pad_value=pad_value)
+ self.conv = ConvLayer(
+ nkernels=nkernels,
+ norm=norm,
+ last_relu=last_relu,
+ padding_mode=padding_mode,
+ )
+
+ def forward(self, input):
+ return self.conv(input)
+
+
+class DownConvBlock(TemporallySharedBlock):
+ def __init__(
+ self,
+ d_in,
+ d_out,
+ k,
+ s,
+ p,
+ pad_value=None,
+ norm="batch",
+ padding_mode="reflect",
+ ):
+ super(DownConvBlock, self).__init__(pad_value=pad_value)
+ self.down = ConvLayer(
+ nkernels=[d_in, d_in],
+ norm=norm,
+ k=k,
+ s=s,
+ p=p,
+ padding_mode=padding_mode,
+ )
+ self.conv1 = ConvLayer(
+ nkernels=[d_in, d_out],
+ norm=norm,
+ padding_mode=padding_mode,
+ )
+ self.conv2 = ConvLayer(
+ nkernels=[d_out, d_out],
+ norm=norm,
+ padding_mode=padding_mode,
+ )
+
+ def forward(self, input):
+ out = self.down(input)
+ out = self.conv1(out)
+ out = out + self.conv2(out)
+ return out
+
+
+class UpConvBlock(nn.Layer):
+ def __init__(
+ self, d_in, d_out, k, s, p, norm="batch", d_skip=None, padding_mode="reflect"
+ ):
+ super(UpConvBlock, self).__init__()
+ d = d_out if d_skip is None else d_skip
+ self.skip_conv = nn.Sequential(
+ nn.Conv2D(in_channels=d, out_channels=d, kernel_size=1),
+ nn.BatchNorm2D(d),
+ nn.ReLU(),
+ )
+ self.up = nn.Sequential(
+ nn.Conv2DTranspose(
+ in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p
+ ),
+ nn.BatchNorm2D(d_out),
+ nn.ReLU(),
+ )
+ self.conv1 = ConvLayer(
+ nkernels=[d_out + d, d_out], norm=norm, padding_mode=padding_mode
+ )
+ self.conv2 = ConvLayer(
+ nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode
+ )
+
+ def forward(self, input, skip):
+ out = self.up(input)
+ out = paddle.concat([out, self.skip_conv(skip)], axis=1)
+ out = self.conv1(out)
+ out = out + self.conv2(out)
+ return out
+
+
+class Temporal_Aggregator(nn.Layer):
+ def __init__(self, mode="mean"):
+ super(Temporal_Aggregator, self).__init__()
+ self.mode = mode
+
+ def forward(self, x, pad_mask=None, attn_mask=None):
+ if pad_mask is not None and pad_mask.any():
+ if self.mode == "att_group":
+ n_heads, b, t, h, w = attn_mask.shape
+ attn = attn_mask.reshape([n_heads * b, t, h, w])
+
+ if x.shape[-2] > w:
+ attn = nn.functional.interpolate(
+ attn, size=x.shape[-2:], mode="bilinear", align_corners=False
+ )
+ else:
+ attn = nn.functional.avg_pool2d(attn, kernel_size=w // x.shape[-2])
+
+ attn = attn.reshape([n_heads, b, t, *x.shape[-2:]])
+ attn = attn * (~pad_mask).astype("float32")[None, :, :, None, None]
+
+ # Split x into n_heads chunks along axis 2
+ chunk_size = x.shape[2] // n_heads
+ x_chunks = []
+ for i in range(n_heads):
+ start_idx = i * chunk_size
+ end_idx = (i + 1) * chunk_size
+ x_chunks.append(x[:, :, start_idx:end_idx, :, :])
+ out = paddle.stack(x_chunks) # hxBxTxC/hxHxW
+ out = attn[:, :, :, None, :, :] * out
+ out = out.sum(axis=2) # sum on temporal dim -> hxBxC/hxHxW
+ out = paddle.concat([group for group in out], axis=1) # -> BxCxHxW
+ return out
+ elif self.mode == "att_mean":
+ attn = attn_mask.mean(axis=0) # average over heads -> BxTxHxW
+ attn = nn.functional.interpolate(
+ attn, size=x.shape[-2:], mode="bilinear", align_corners=False
+ )
+ attn = attn * (~pad_mask).astype("float32")[:, :, None, None]
+ out = (x * attn[:, :, None, :, :]).sum(axis=1)
+ return out
+ elif self.mode == "mean":
+ out = x * (~pad_mask).astype("float32")[:, :, None, None, None]
+ out = out.sum(axis=1) / (~pad_mask).sum(axis=1)[:, None, None, None]
+ return out
+ else:
+ if self.mode == "att_group":
+ n_heads, b, t, h, w = attn_mask.shape
+ attn = attn_mask.reshape([n_heads * b, t, h, w])
+ if x.shape[-2] > w:
+ attn = nn.functional.interpolate(
+ attn, size=x.shape[-2:], mode="bilinear", align_corners=False
+ )
+ else:
+ attn = nn.functional.avg_pool2d(attn, kernel_size=w // x.shape[-2])
+ attn = attn.reshape([n_heads, b, t, *x.shape[-2:]])
+ # Split x into n_heads chunks along axis 2
+ chunk_size = x.shape[2] // n_heads
+ x_chunks = []
+ for i in range(n_heads):
+ start_idx = i * chunk_size
+ end_idx = (i + 1) * chunk_size
+ x_chunks.append(x[:, :, start_idx:end_idx, :, :])
+ out = paddle.stack(x_chunks) # hxBxTxC/hxHxW
+ out = attn[:, :, :, None, :, :] * out
+ out = out.sum(axis=2) # sum on temporal dim -> hxBxC/hxHxW
+ out = paddle.concat([group for group in out], axis=1) # -> BxCxHxW
+ return out
+ elif self.mode == "att_mean":
+ attn = attn_mask.mean(axis=0) # average over heads -> BxTxHxW
+ attn = nn.functional.interpolate(
+ attn, size=x.shape[-2:], mode="bilinear", align_corners=False
+ )
+ out = (x * attn[:, :, None, :, :]).sum(axis=1)
+ return out
+ elif self.mode == "mean":
+ return x.mean(axis=1)
+
+
+class RecUNet(nn.Layer):
+ """Recurrent U-Net architecture. Similar to the U-TAE architecture but
+ the L-TAE is replaced by a recurrent network
+ and temporal averages are computed for the skip connections."""
+
+ def __init__(
+ self,
+ input_dim,
+ encoder_widths=[64, 64, 64, 128],
+ decoder_widths=[32, 32, 64, 128],
+ out_conv=[32, 20],
+ str_conv_k=4,
+ str_conv_s=2,
+ str_conv_p=1,
+ temporal="lstm",
+ input_size=128,
+ encoder_norm="group",
+ hidden_dim=128,
+ encoder=False,
+ padding_mode="reflect",
+ pad_value=0,
+ ):
+ super(RecUNet, self).__init__()
+ self.n_stages = len(encoder_widths)
+ self.temporal = temporal
+ self.encoder_widths = encoder_widths
+ self.decoder_widths = decoder_widths
+ self.enc_dim = (
+ decoder_widths[0] if decoder_widths is not None else encoder_widths[0]
+ )
+ self.stack_dim = (
+ sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths)
+ )
+ self.pad_value = pad_value
+
+ self.encoder = encoder
+ if encoder:
+ self.return_maps = True
+ else:
+ self.return_maps = False
+
+ if decoder_widths is not None:
+ assert len(encoder_widths) == len(decoder_widths)
+ assert encoder_widths[-1] == decoder_widths[-1]
+ else:
+ decoder_widths = encoder_widths
+
+ self.in_conv = ConvBlock(
+ nkernels=[input_dim] + [encoder_widths[0], encoder_widths[0]],
+ pad_value=pad_value,
+ norm=encoder_norm,
+ )
+
+ self.down_blocks = nn.LayerList(
+ [
+ DownConvBlock(
+ d_in=encoder_widths[i],
+ d_out=encoder_widths[i + 1],
+ k=str_conv_k,
+ s=str_conv_s,
+ p=str_conv_p,
+ pad_value=pad_value,
+ norm=encoder_norm,
+ padding_mode=padding_mode,
+ )
+ for i in range(self.n_stages - 1)
+ ]
+ )
+ self.up_blocks = nn.LayerList(
+ [
+ UpConvBlock(
+ d_in=decoder_widths[i],
+ d_out=decoder_widths[i - 1],
+ d_skip=encoder_widths[i - 1],
+ k=str_conv_k,
+ s=str_conv_s,
+ p=str_conv_p,
+ norm=encoder_norm,
+ padding_mode=padding_mode,
+ )
+ for i in range(self.n_stages - 1, 0, -1)
+ ]
+ )
+ self.temporal_aggregator = Temporal_Aggregator(mode="mean")
+
+ if temporal == "mean":
+ self.temporal_encoder = Temporal_Aggregator(mode="mean")
+ elif temporal == "lstm":
+ size = int(input_size / str_conv_s ** (self.n_stages - 1))
+ self.temporal_encoder = ConvLSTM(
+ input_dim=encoder_widths[-1],
+ input_size=(size, size),
+ hidden_dim=hidden_dim,
+ kernel_size=(3, 3),
+ )
+ self.out_convlstm = nn.Conv2D(
+ in_channels=hidden_dim,
+ out_channels=encoder_widths[-1],
+ kernel_size=3,
+ padding=1,
+ )
+ elif temporal == "blstm":
+ size = int(input_size / str_conv_s ** (self.n_stages - 1))
+ self.temporal_encoder = BConvLSTM(
+ input_dim=encoder_widths[-1],
+ input_size=(size, size),
+ hidden_dim=hidden_dim,
+ kernel_size=(3, 3),
+ )
+ self.out_convlstm = nn.Conv2D(
+ in_channels=2 * hidden_dim,
+ out_channels=encoder_widths[-1],
+ kernel_size=3,
+ padding=1,
+ )
+ elif temporal == "mono":
+ self.temporal_encoder = None
+ self.out_conv = ConvBlock(
+ nkernels=[decoder_widths[0]] + out_conv, padding_mode=padding_mode
+ )
+
+ def forward(self, input, batch_positions=None):
+ pad_mask = (
+ (input == self.pad_value).all(axis=-1).all(axis=-1).all(axis=-1)
+ ) # BxT pad mask
+
+ out = self.in_conv.smart_forward(input)
+
+ feature_maps = [out]
+ # ENCODER
+ for i in range(self.n_stages - 1):
+ out = self.down_blocks[i].smart_forward(feature_maps[-1])
+ feature_maps.append(out)
+
+ # Temporal encoder
+ if self.temporal == "mean":
+ out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask)
+ elif self.temporal == "lstm":
+ _, out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask)
+ out = out[0][1] # take last cell state as embedding
+ out = self.out_convlstm(out)
+ elif self.temporal == "blstm":
+ out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask)
+ out = self.out_convlstm(out)
+ elif self.temporal == "mono":
+ out = feature_maps[-1]
+
+ if self.return_maps:
+ maps = [out]
+ for i in range(self.n_stages - 1):
+ if self.temporal != "mono":
+ skip = self.temporal_aggregator(
+ feature_maps[-(i + 2)], pad_mask=pad_mask
+ )
+ else:
+ skip = feature_maps[-(i + 2)]
+ out = self.up_blocks[i](out, skip)
+ if self.return_maps:
+ maps.append(out)
+
+ if self.encoder:
+ return out, maps
+ else:
+ out = self.out_conv(out)
+ if self.return_maps:
+ return out, maps
+ else:
+ return out
diff --git a/examples/UTAE/src/dataset.py b/examples/UTAE/src/dataset.py
new file mode 100644
index 000000000..c264d36a1
--- /dev/null
+++ b/examples/UTAE/src/dataset.py
@@ -0,0 +1,291 @@
+import json
+import os
+from datetime import datetime
+
+import geopandas as gpd
+import numpy as np
+import paddle
+import paddle.io as pio
+import pandas as pd
+
+
+class PASTIS_Dataset(pio.Dataset):
+ def __init__(
+ self,
+ folder,
+ norm=True,
+ target="semantic",
+ cache=False,
+ mem16=False,
+ folds=None,
+ reference_date="2018-09-01",
+ class_mapping=None,
+ mono_date=None,
+ sats=["S2"],
+ ):
+
+ super(PASTIS_Dataset, self).__init__()
+ self.folder = folder
+ self.norm = norm
+ self.reference_date = datetime(*map(int, reference_date.split("-")))
+ self.cache = cache
+ self.mem16 = mem16
+ self.mono_date = None
+ if mono_date is not None:
+ self.mono_date = (
+ datetime(*map(int, mono_date.split("-")))
+ if "-" in mono_date
+ else int(mono_date)
+ )
+ self.memory = {}
+ self.memory_dates = {}
+ self.class_mapping = (
+ np.vectorize(lambda x: class_mapping[x])
+ if class_mapping is not None
+ else class_mapping
+ )
+ self.target = target
+ self.sats = sats
+
+ # Get metadata
+ print("Reading patch metadata . . .")
+ self.meta_patch = gpd.read_file(os.path.join(folder, "metadata.geojson"))
+ self.meta_patch.index = self.meta_patch["ID_PATCH"].astype(int)
+ self.meta_patch.sort_index(inplace=True)
+
+ self.date_tables = {s: None for s in sats}
+ self.date_range = np.array(range(-200, 600))
+ for s in sats:
+ dates = self.meta_patch["dates-{}".format(s)]
+ date_table = pd.DataFrame(
+ index=self.meta_patch.index, columns=self.date_range, dtype=int
+ )
+ for pid, date_seq in dates.items():
+ if type(date_seq) == str:
+ date_seq = json.loads(date_seq)
+ d = pd.DataFrame().from_dict(date_seq, orient="index")
+ d = d[0].apply(
+ lambda x: (
+ datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:]))
+ - self.reference_date
+ ).days
+ )
+ date_table.loc[pid, d.values] = 1
+ date_table = date_table.fillna(0)
+ self.date_tables[s] = {
+ index: np.array(list(d.values()))
+ for index, d in date_table.to_dict(orient="index").items()
+ }
+
+ print("Done.")
+
+ # Select Fold samples
+ if folds is not None:
+ self.meta_patch = pd.concat(
+ [self.meta_patch[self.meta_patch["Fold"] == f] for f in folds]
+ )
+
+ self.len = self.meta_patch.shape[0]
+ self.id_patches = self.meta_patch.index
+
+ # Get normalisation values
+ if norm:
+ self.norm = {}
+ for s in self.sats:
+ with open(
+ os.path.join(folder, "NORM_{}_patch.json".format(s)), "r"
+ ) as file:
+ normvals = json.loads(file.read())
+ selected_folds = folds if folds is not None else range(1, 6)
+ means = [normvals["Fold_{}".format(f)]["mean"] for f in selected_folds]
+ stds = [normvals["Fold_{}".format(f)]["std"] for f in selected_folds]
+ self.norm[s] = np.stack(means).mean(axis=0), np.stack(stds).mean(axis=0)
+ self.norm[s] = (
+ paddle.to_tensor(self.norm[s][0], dtype="float32"),
+ paddle.to_tensor(self.norm[s][1], dtype="float32"),
+ )
+ else:
+ self.norm = None
+ print("Dataset ready.")
+
+ def __len__(self):
+ return self.len
+
+ def get_dates(self, id_patch, sat):
+ return self.date_range[np.where(self.date_tables[sat][id_patch] == 1)[0]]
+
+ def __getitem__(self, item):
+ id_patch = self.id_patches[item]
+
+ # Retrieve and prepare satellite data
+ if not self.cache or item not in self.memory.keys():
+ data = {
+ satellite: np.load(
+ os.path.join(
+ self.folder,
+ "DATA_{}".format(satellite),
+ "{}_{}.npy".format(satellite, id_patch),
+ )
+ ).astype(np.float32)
+ for satellite in self.sats
+ } # T x C x H x W arrays
+ data = {s: paddle.to_tensor(a) for s, a in data.items()}
+
+ if self.norm is not None:
+ data = {
+ s: (d - self.norm[s][0][None, :, None, None])
+ / self.norm[s][1][None, :, None, None]
+ for s, d in data.items()
+ }
+
+ if self.target == "semantic":
+ target = np.load(
+ os.path.join(
+ self.folder, "ANNOTATIONS", "TARGET_{}.npy".format(id_patch)
+ )
+ )
+ target = paddle.to_tensor(target[0].astype(int))
+
+ if self.class_mapping is not None:
+ target = self.class_mapping(target)
+
+ elif self.target == "instance":
+ heatmap = np.load(
+ os.path.join(
+ self.folder,
+ "INSTANCE_ANNOTATIONS",
+ "HEATMAP_{}.npy".format(id_patch),
+ )
+ )
+
+ instance_ids = np.load(
+ os.path.join(
+ self.folder,
+ "INSTANCE_ANNOTATIONS",
+ "INSTANCES_{}.npy".format(id_patch),
+ )
+ )
+ pixel_to_object_mapping = np.load(
+ os.path.join(
+ self.folder,
+ "INSTANCE_ANNOTATIONS",
+ "ZONES_{}.npy".format(id_patch),
+ )
+ )
+
+ pixel_semantic_annotation = np.load(
+ os.path.join(
+ self.folder, "ANNOTATIONS", "TARGET_{}.npy".format(id_patch)
+ )
+ )
+
+ if self.class_mapping is not None:
+ pixel_semantic_annotation = self.class_mapping(
+ pixel_semantic_annotation[0]
+ )
+ else:
+ pixel_semantic_annotation = pixel_semantic_annotation[0]
+
+ size = np.zeros((*instance_ids.shape, 2))
+ object_semantic_annotation = np.zeros(instance_ids.shape)
+ for instance_id in np.unique(instance_ids):
+ if instance_id != 0:
+ h = (instance_ids == instance_id).any(axis=-1).sum()
+ w = (instance_ids == instance_id).any(axis=-2).sum()
+ size[pixel_to_object_mapping == instance_id] = (h, w)
+ object_semantic_annotation[
+ pixel_to_object_mapping == instance_id
+ ] = pixel_semantic_annotation[instance_ids == instance_id][0]
+
+ target = paddle.to_tensor(
+ np.concatenate(
+ [
+ heatmap[:, :, None], # 0
+ instance_ids[:, :, None], # 1
+ pixel_to_object_mapping[:, :, None], # 2
+ size, # 3-4
+ object_semantic_annotation[:, :, None], # 5
+ pixel_semantic_annotation[:, :, None], # 6
+ ],
+ axis=-1,
+ ),
+ dtype="float32",
+ )
+
+ if self.cache:
+ if self.mem16:
+ self.memory[item] = [
+ {k: v.astype("float16") for k, v in data.items()},
+ target,
+ ]
+ else:
+ self.memory[item] = [data, target]
+
+ else:
+ data, target = self.memory[item]
+ if self.mem16:
+ data = {k: v.astype("float32") for k, v in data.items()}
+
+ # Retrieve date sequences
+ if not self.cache or id_patch not in self.memory_dates.keys():
+ dates = {
+ s: paddle.to_tensor(self.get_dates(id_patch, s)) for s in self.sats
+ }
+ if self.cache:
+ self.memory_dates[id_patch] = dates
+ else:
+ dates = self.memory_dates[id_patch]
+
+ if self.mono_date is not None:
+ if isinstance(self.mono_date, int):
+ data = {s: data[s][self.mono_date].unsqueeze(0) for s in self.sats}
+ dates = {s: dates[s][self.mono_date] for s in self.sats}
+ else:
+ mono_delta = (self.mono_date - self.reference_date).days
+ mono_date = {
+ s: int((dates[s] - mono_delta).abs().argmin()) for s in self.sats
+ }
+ data = {s: data[s][mono_date[s]].unsqueeze(0) for s in self.sats}
+ dates = {s: dates[s][mono_date[s]] for s in self.sats}
+
+ if self.mem16:
+ data = {k: v.astype("float32") for k, v in data.items()}
+
+ if len(self.sats) == 1:
+ data = data[self.sats[0]]
+ dates = dates[self.sats[0]]
+
+ return (data, dates), target
+
+
+def prepare_dates(date_dict, reference_date):
+ d = pd.DataFrame().from_dict(date_dict, orient="index")
+ d = d[0].apply(
+ lambda x: (
+ datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:]))
+ - reference_date
+ ).days
+ )
+ return d.values
+
+
+def compute_norm_vals(folder, sat):
+ norm_vals = {}
+ for fold in range(1, 6):
+ dt = PASTIS_Dataset(folder=folder, norm=False, folds=[fold], sats=[sat])
+ means = []
+ stds = []
+ for i, b in enumerate(dt):
+ print("{}/{}".format(i, len(dt)), end="\r")
+ data = b[0][0][sat] # T x C x H x W
+ data = data.transpose([1, 0, 2, 3]) # C x T x H x W
+ means.append(data.reshape([data.shape[0], -1]).mean(axis=-1).numpy())
+ stds.append(data.reshape([data.shape[0], -1]).std(axis=-1).numpy())
+
+ mean = np.stack(means).mean(axis=0).astype(float)
+ std = np.stack(stds).mean(axis=0).astype(float)
+
+ norm_vals["Fold_{}".format(fold)] = dict(mean=list(mean), std=list(std))
+
+ with open(os.path.join(folder, "NORM_{}_patch.json".format(sat)), "w") as file:
+ file.write(json.dumps(norm_vals, indent=4))
diff --git a/examples/UTAE/src/learning/metrics.py b/examples/UTAE/src/learning/metrics.py
new file mode 100644
index 000000000..d292374b5
--- /dev/null
+++ b/examples/UTAE/src/learning/metrics.py
@@ -0,0 +1,53 @@
+"""
+Metrics utilities (Paddle Version)
+"""
+import numpy as np
+
+"""
+Compute per-class and overall metrics from confusion matrix
+"""
+
+
+def confusion_matrix_analysis(cm):
+
+ n_classes = cm.shape[0]
+
+ # Overall accuracy
+ acc = np.diag(cm).sum() / (cm.sum() + 1e-15)
+
+ # Per-class metrics
+ per_class_metrics = {}
+ ious = []
+
+ for i in range(n_classes):
+ # True positives, false positives, false negatives
+ tp = cm[i, i]
+ fp = cm[:, i].sum() - tp
+ fn = cm[i, :].sum() - tp
+
+ # Precision, recall, F1
+ precision = tp / (tp + fp + 1e-15)
+ recall = tp / (tp + fn + 1e-15)
+ f1 = 2 * precision * recall / (precision + recall + 1e-15)
+
+ # IoU
+ union = tp + fp + fn
+ iou = tp / (union + 1e-15)
+ ious.append(iou)
+
+ per_class_metrics[f"class_{i}"] = {
+ "precision": precision,
+ "recall": recall,
+ "f1": f1,
+ "iou": iou,
+ }
+
+ # Mean metrics
+ mean_iou = np.mean(ious)
+
+ return {
+ "overall_accuracy": acc,
+ "mean_iou": mean_iou,
+ "per_class": per_class_metrics,
+ "confusion_matrix": cm.tolist(),
+ }
diff --git a/examples/UTAE/src/learning/miou.py b/examples/UTAE/src/learning/miou.py
new file mode 100644
index 000000000..622e61f63
--- /dev/null
+++ b/examples/UTAE/src/learning/miou.py
@@ -0,0 +1,78 @@
+"""
+IoU metric computation (Paddle Version)
+"""
+import numpy as np
+import paddle
+
+
+class IoU:
+ def __init__(self, num_classes, ignore_index=-1, cm_device="cpu"):
+ self.num_classes = num_classes
+ self.ignore_index = ignore_index
+ self.cm_device = cm_device
+ self.confusion_matrix = np.zeros((num_classes, num_classes))
+
+ """
+ Add predictions and targets to confusion matrix
+ """
+
+ def add(self, pred, target):
+ # Convert to numpy if tensors
+ if isinstance(pred, paddle.Tensor):
+ pred = pred.cpu().numpy()
+ if isinstance(target, paddle.Tensor):
+ target = target.cpu().numpy()
+
+ # Flatten arrays
+ pred = pred.flatten()
+ target = target.flatten()
+
+ # Remove ignore index
+ if self.ignore_index is not None:
+ mask = target != self.ignore_index
+ pred = pred[mask]
+ target = target[mask]
+
+ # Compute confusion matrix
+ for t, p in zip(target.flatten(), pred.flatten()):
+ if 0 <= t < self.num_classes and 0 <= p < self.num_classes:
+ self.confusion_matrix[t, p] += 1
+
+ """
+ Get mean IoU and accuracy from confusion matrix
+ """
+
+ def get_miou_acc(self):
+
+ # Overall accuracy
+ acc = np.diag(self.confusion_matrix).sum() / (
+ self.confusion_matrix.sum() + 1e-15
+ )
+
+ # Per-class IoU
+ ious = []
+ for i in range(self.num_classes):
+ intersection = self.confusion_matrix[i, i]
+ union = (
+ self.confusion_matrix[i, :].sum()
+ + self.confusion_matrix[:, i].sum()
+ - intersection
+ )
+
+ if union > 0:
+ ious.append(intersection / union)
+ else:
+ ious.append(0.0)
+
+ # Mean IoU
+ miou = np.mean(ious)
+
+ return miou, acc
+
+ """
+ Reset confusion matrix
+ """
+
+ def reset(self):
+
+ self.confusion_matrix.fill(0)
diff --git a/examples/UTAE/src/learning/weight_init.py b/examples/UTAE/src/learning/weight_init.py
new file mode 100644
index 000000000..b79c04f4f
--- /dev/null
+++ b/examples/UTAE/src/learning/weight_init.py
@@ -0,0 +1,24 @@
+"""
+Weight initialization utilities (Paddle Version)
+"""
+import paddle.nn as nn
+
+"""
+Initialize model weights
+"""
+
+
+def weight_init(model):
+
+ for layer in model.sublayers():
+ if isinstance(layer, (nn.Conv2D, nn.Conv1D)):
+ nn.initializer.XavierUniform()(layer.weight)
+ if layer.bias is not None:
+ nn.initializer.Constant(0.0)(layer.bias)
+ elif isinstance(layer, (nn.BatchNorm2D, nn.BatchNorm1D, nn.GroupNorm)):
+ nn.initializer.Constant(1.0)(layer.weight)
+ nn.initializer.Constant(0.0)(layer.bias)
+ elif isinstance(layer, nn.Linear):
+ nn.initializer.XavierUniform()(layer.weight)
+ if layer.bias is not None:
+ nn.initializer.Constant(0.0)(layer.bias)
diff --git a/examples/UTAE/src/model_utils.py b/examples/UTAE/src/model_utils.py
new file mode 100644
index 000000000..bda3eac5a
--- /dev/null
+++ b/examples/UTAE/src/model_utils.py
@@ -0,0 +1,109 @@
+"""
+Model utilities (Paddle Version)
+"""
+from src.backbones.utae import UTAE
+from src.backbones.utae import RecUNet
+
+"""
+Get the model based on configuration
+"""
+
+
+def get_model(config, mode="semantic"):
+
+ if mode == "panoptic":
+ # For panoptic segmentation, create PaPs model
+ if config.backbone == "utae":
+ from src.panoptic.paps import PaPs
+
+ encoder = UTAE(
+ input_dim=10, # PASTIS has 10 spectral bands
+ encoder_widths=eval(config.encoder_widths),
+ decoder_widths=eval(config.decoder_widths),
+ out_conv=eval(config.out_conv),
+ str_conv_k=config.str_conv_k,
+ str_conv_s=config.str_conv_s,
+ str_conv_p=config.str_conv_p,
+ agg_mode=config.agg_mode,
+ encoder_norm=config.encoder_norm,
+ n_head=config.n_head,
+ d_model=config.d_model,
+ d_k=config.d_k,
+ encoder=True, # Important: set to True for PaPs
+ return_maps=True, # Important: return feature maps
+ pad_value=config.pad_value,
+ padding_mode=config.padding_mode,
+ )
+
+ model = PaPs(
+ encoder=encoder,
+ num_classes=config.num_classes,
+ shape_size=config.shape_size,
+ mask_conv=config.mask_conv,
+ min_confidence=config.min_confidence,
+ min_remain=config.min_remain,
+ mask_threshold=config.mask_threshold,
+ )
+ else:
+ raise NotImplementedError(
+ f"Backbone {config.backbone} not implemented for panoptic mode"
+ )
+ elif config.model == "utae":
+ model = UTAE(
+ input_dim=10, # Sentinel-2 has 10 bands
+ encoder_widths=eval(config.encoder_widths),
+ decoder_widths=eval(config.decoder_widths),
+ out_conv=eval(config.out_conv),
+ str_conv_k=config.str_conv_k,
+ str_conv_s=config.str_conv_s,
+ str_conv_p=config.str_conv_p,
+ agg_mode=config.agg_mode,
+ encoder_norm=config.encoder_norm,
+ n_head=config.n_head,
+ d_model=config.d_model,
+ d_k=config.d_k,
+ pad_value=config.pad_value,
+ padding_mode=config.padding_mode,
+ )
+ elif config.model == "uconvlstm":
+ model = RecUNet(
+ input_dim=10,
+ encoder_widths=eval(config.encoder_widths),
+ decoder_widths=eval(config.decoder_widths),
+ out_conv=eval(config.out_conv),
+ str_conv_k=config.str_conv_k,
+ str_conv_s=config.str_conv_s,
+ str_conv_p=config.str_conv_p,
+ temporal="lstm",
+ encoder_norm=config.encoder_norm,
+ padding_mode=config.padding_mode,
+ pad_value=config.pad_value,
+ )
+ elif config.model == "buconvlstm":
+ model = RecUNet(
+ input_dim=10,
+ encoder_widths=eval(config.encoder_widths),
+ decoder_widths=eval(config.decoder_widths),
+ out_conv=eval(config.out_conv),
+ str_conv_k=config.str_conv_k,
+ str_conv_s=config.str_conv_s,
+ str_conv_p=config.str_conv_p,
+ temporal="blstm",
+ encoder_norm=config.encoder_norm,
+ padding_mode=config.padding_mode,
+ pad_value=config.pad_value,
+ )
+ else:
+ raise ValueError(f"Unknown model: {config.model}")
+
+ return model
+
+
+"""
+Get number of trainable parameters
+"""
+
+
+def get_ntrainparams(model):
+
+ return sum(p.numel() for p in model.parameters() if not p.stop_gradient)
diff --git a/examples/UTAE/src/panoptic/FocalLoss.py b/examples/UTAE/src/panoptic/FocalLoss.py
new file mode 100644
index 000000000..ab8261115
--- /dev/null
+++ b/examples/UTAE/src/panoptic/FocalLoss.py
@@ -0,0 +1,58 @@
+"""
+Converted to PaddlePaddle
+"""
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class FocalLoss(nn.Layer):
+ def __init__(self, gamma=0, alpha=None, size_average=True, ignore_label=None):
+ super(FocalLoss, self).__init__()
+ self.gamma = gamma
+ self.alpha = alpha
+ if isinstance(alpha, (float, int)):
+ self.alpha = paddle.to_tensor([alpha, 1 - alpha])
+ if isinstance(alpha, list):
+ self.alpha = paddle.to_tensor(alpha)
+ self.size_average = size_average
+ self.ignore_label = ignore_label
+
+ def forward(self, input, target):
+ if input.ndim > 2:
+ input = input.reshape(
+ [input.shape[0], input.shape[1], -1]
+ ) # N,C,H,W => N,C,H*W
+ input = input.transpose([0, 2, 1]) # N,C,H*W => N,H*W,C
+ input = input.reshape([-1, input.shape[2]]) # N,H*W,C => N*H*W,C
+ target = target.reshape([-1, 1])
+
+ if input.squeeze(1).ndim == 1:
+ logpt = F.sigmoid(input)
+ logpt = logpt.flatten()
+ else:
+ logpt = F.log_softmax(input, axis=-1)
+ logpt = paddle.gather_nd(
+ logpt,
+ paddle.stack([paddle.arange(logpt.shape[0]), target.squeeze()], axis=1),
+ )
+ logpt = logpt.flatten()
+
+ pt = paddle.exp(logpt)
+
+ if self.alpha is not None:
+ if self.alpha.dtype != input.dtype:
+ self.alpha = self.alpha.astype(input.dtype)
+ at = paddle.gather(self.alpha, target.flatten().astype("int64"))
+ logpt = logpt * at
+
+ loss = -1 * (1 - pt) ** self.gamma * logpt
+
+ if self.ignore_label is not None:
+ valid_mask = target[:, 0] != self.ignore_label
+ loss = loss[valid_mask]
+
+ if self.size_average:
+ return loss.mean()
+ else:
+ return loss.sum()
diff --git a/examples/UTAE/src/panoptic/geom_utils.py b/examples/UTAE/src/panoptic/geom_utils.py
new file mode 100644
index 000000000..03b9bddc6
--- /dev/null
+++ b/examples/UTAE/src/panoptic/geom_utils.py
@@ -0,0 +1,75 @@
+"""
+Geometric utilities (Paddle Version)
+Converted to PaddlePaddle
+"""
+
+import numpy as np
+import paddle
+
+
+def get_bbox(bin_mask):
+ """Input single (H,W) bin mask"""
+ if isinstance(bin_mask, paddle.Tensor):
+ xl, xr = paddle.nonzero(bin_mask.sum(axis=-2))[0][[0, -1]]
+ yt, yb = paddle.nonzero(bin_mask.sum(axis=-1))[0][[0, -1]]
+ return paddle.stack([xl, yt, xr, yb])
+ else:
+ xl, xr = np.where(bin_mask.sum(axis=-2))[0][[0, -1]]
+ yt, yb = np.where(bin_mask.sum(axis=-1))[0][[0, -1]]
+ return np.stack([xl, yt, xr, yb])
+
+
+def bbox_area(bbox):
+ """Input (N,4) set of bounding boxes"""
+ out = bbox.astype("float32")
+ return (out[:, 2] - out[:, 0]) * (out[:, 3] - out[:, 1])
+
+
+def intersect(box_a, box_b):
+ """
+ taken from https://github.com/amdegroot/ssd.pytorch
+ We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.shape[0]
+ B = box_b.shape[0]
+ max_xy = paddle.minimum(
+ box_a[:, 2:].unsqueeze(1).expand([A, B, 2]),
+ box_b[:, 2:].unsqueeze(0).expand([A, B, 2]),
+ )
+ min_xy = paddle.maximum(
+ box_a[:, :2].unsqueeze(1).expand([A, B, 2]),
+ box_b[:, :2].unsqueeze(0).expand([A, B, 2]),
+ )
+ inter = paddle.clip((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def bbox_iou(bbox1, bbox2):
+ """Two sets of (N,4) bounding boxes"""
+ area1 = bbox_area(bbox1)
+ area2 = bbox_area(bbox2)
+ inter = paddle.diag(intersect(bbox1, bbox2))
+ union = area1 + area2 - inter
+ valid_mask = union != 0
+ return inter[valid_mask] / union[valid_mask]
+
+
+def bbox_validzone(bbox, shape):
+ """Given an image shape, get the coordinate (in the bbox reference)
+ of the pixels that are within the image boundaries"""
+ H, W = shape
+ wt, ht, wb, hb = bbox
+
+ val_ht = -ht if ht < 0 else 0
+ val_wt = -wt if wt < 0 else 0
+ val_hb = H - ht if hb > H else hb - ht
+ val_wb = W - wt if wb > W else wb - wt
+ return (val_wt, val_ht, val_wb, val_hb)
diff --git a/examples/UTAE/src/panoptic/metrics.py b/examples/UTAE/src/panoptic/metrics.py
new file mode 100644
index 000000000..e410aea4c
--- /dev/null
+++ b/examples/UTAE/src/panoptic/metrics.py
@@ -0,0 +1,201 @@
+"""
+Panoptic Metrics (Paddle Version)
+Converted to PaddlePaddle
+"""
+
+import paddle
+
+
+class PanopticMeter:
+ """
+ Meter class for the panoptic metrics as defined by Kirilov et al. :
+ Segmentation Quality (SQ)
+ Recognition Quality (RQ)
+ Panoptic Quality (PQ)
+ The behavior of this meter mimics that of torchnet meters, each predicted batch
+ is added via the add method and the global metrics are retrieved with the value
+ method.
+ Args:
+ num_classes (int): Number of semantic classes (including background and void class).
+ void_label (int): Label for the void class (default 19).
+ background_label (int): Label for the background class (default 0).
+ iou_threshold (float): Threshold used on the IoU of the true vs predicted
+ instance mask. Above the threshold a true instance is counted as True Positive.
+ """
+
+ def __init__(
+ self, num_classes=20, background_label=0, void_label=19, iou_threshold=0.5
+ ):
+
+ self.num_classes = num_classes
+ self.iou_threshold = iou_threshold
+ self.class_list = [c for c in range(num_classes) if c != background_label]
+ self.void_label = void_label
+ if void_label is not None:
+ self.class_list = [c for c in self.class_list if c != void_label]
+ self.counts = paddle.zeros([len(self.class_list), 3])
+ self.cumulative_ious = paddle.zeros([len(self.class_list)])
+
+ def add(self, predictions, target):
+ # Split target tensor - equivalent to torch.split
+ target_splits = paddle.split(target, [1, 1, 1, 2, 1, 1], axis=-1)
+ (
+ target_heatmap,
+ true_instances,
+ zones,
+ size,
+ sem_obj,
+ sem_pix,
+ ) = target_splits
+
+ instance_true = true_instances.squeeze(-1)
+ semantic_true = sem_pix.squeeze(-1)
+
+ instance_pred = predictions["pano_instance"]
+
+ # Handle case when pano_semantic is None (when pseudo_nms=False)
+ if predictions["pano_semantic"] is not None:
+ semantic_pred = predictions["pano_semantic"].argmax(axis=1)
+ else:
+ # Return early with zero metrics when no panoptic predictions available
+ return
+
+ if self.void_label is not None:
+ void_masks = (semantic_true == self.void_label).astype("float32")
+
+ # Ignore Void Objects
+ for batch_idx in range(void_masks.shape[0]):
+ void_mask = void_masks[batch_idx]
+ if void_mask.any():
+ void_instances = instance_true[batch_idx] * void_mask
+ unique_void, void_counts = paddle.unique(
+ void_instances, return_counts=True
+ )
+
+ for void_inst_id, void_inst_area in zip(unique_void, void_counts):
+ if void_inst_id == 0:
+ continue
+
+ pred_instances = instance_pred[batch_idx]
+ unique_pred, pred_counts = paddle.unique(
+ pred_instances, return_counts=True
+ )
+
+ for pred_inst_id, pred_inst_area in zip(
+ unique_pred, pred_counts
+ ):
+ if pred_inst_id == 0:
+ continue
+ inter = (
+ (instance_true[batch_idx] == void_inst_id)
+ * (instance_pred[batch_idx] == pred_inst_id)
+ ).sum()
+ iou = inter.astype("float32") / (
+ void_inst_area + pred_inst_area - inter
+ ).astype("float32")
+ if iou > self.iou_threshold:
+ instance_pred[batch_idx] = paddle.where(
+ instance_pred[batch_idx] == pred_inst_id,
+ paddle.to_tensor(0),
+ instance_pred[batch_idx],
+ )
+ semantic_pred[batch_idx] = paddle.where(
+ instance_pred[batch_idx] == pred_inst_id,
+ paddle.to_tensor(0),
+ semantic_pred[batch_idx],
+ )
+
+ # Ignore Void Pixels
+ instance_pred = paddle.where(void_masks, paddle.to_tensor(0), instance_pred)
+ semantic_pred = paddle.where(void_masks, paddle.to_tensor(0), semantic_pred)
+
+ # Compute metrics for each class
+ for i, class_id in enumerate(self.class_list):
+ TP = 0
+ n_preds = 0
+ n_true = 0
+ ious = []
+
+ for batch_idx in range(instance_true.shape[0]):
+ instance_mask = instance_true[batch_idx]
+ class_mask_gt = (semantic_true[batch_idx] == class_id).astype("float32")
+ class_mask_p = (semantic_pred[batch_idx] == class_id).astype("float32")
+
+ pred_class_instances = instance_pred[batch_idx] * class_mask_p
+ true_class_instances = instance_mask * class_mask_gt
+
+ n_preds += (
+ int(paddle.unique(pred_class_instances).shape[0]) - 1
+ ) # do not count 0
+ n_true += int(paddle.unique(true_class_instances).shape[0]) - 1
+
+ if n_preds == 0 or n_true == 0:
+ continue # no true positives in that case
+
+ unique_true, true_counts = paddle.unique(
+ true_class_instances, return_counts=True
+ )
+ for true_inst_id, true_inst_area in zip(unique_true, true_counts):
+ if true_inst_id == 0: # masked segments
+ continue
+
+ unique_pred, pred_counts = paddle.unique(
+ pred_class_instances, return_counts=True
+ )
+ for pred_inst_id, pred_inst_area in zip(unique_pred, pred_counts):
+ if pred_inst_id == 0:
+ continue
+ inter = (
+ (instance_mask == true_inst_id)
+ * (instance_pred[batch_idx] == pred_inst_id)
+ ).sum()
+ iou = inter.astype("float32") / (
+ true_inst_area + pred_inst_area - inter
+ ).astype("float32")
+
+ if iou > self.iou_threshold:
+ TP += 1
+ ious.append(iou)
+
+ FP = n_preds - TP
+ FN = n_true - TP
+
+ self.counts[i] += paddle.to_tensor([TP, FP, FN], dtype="float32")
+ if len(ious) > 0:
+ self.cumulative_ious[i] += paddle.stack(ious).sum()
+
+ def value(self, per_class=False):
+ TP, FP, FN = paddle.split(self.counts.astype("float32"), 3, axis=-1)
+ SQ = self.cumulative_ious / TP.squeeze()
+
+ # Handle NaN and Inf values
+ nan_mask = paddle.isnan(SQ) | paddle.isinf(SQ)
+ SQ = paddle.where(nan_mask, paddle.to_tensor(0.0), SQ)
+
+ RQ = TP / (TP + 0.5 * FP + 0.5 * FN)
+ PQ = SQ * RQ.squeeze(-1)
+
+ if per_class:
+ return SQ, RQ, PQ
+ else:
+ valid_mask = ~paddle.isnan(PQ)
+ if valid_mask.any():
+ return (
+ SQ[valid_mask].mean(),
+ RQ[valid_mask].mean(),
+ PQ[valid_mask].mean(),
+ )
+ else:
+ return (
+ paddle.to_tensor(0.0),
+ paddle.to_tensor(0.0),
+ paddle.to_tensor(0.0),
+ )
+
+ def get_table(self):
+ table = (
+ paddle.concat([self.counts, self.cumulative_ious[:, None]], axis=-1)
+ .cpu()
+ .numpy()
+ )
+ return table
diff --git a/examples/UTAE/src/panoptic/paps.py b/examples/UTAE/src/panoptic/paps.py
new file mode 100644
index 000000000..95dc331d7
--- /dev/null
+++ b/examples/UTAE/src/panoptic/paps.py
@@ -0,0 +1,503 @@
+"""
+PaPs Implementation (Paddle Version)
+Converted to PaddlePaddle
+"""
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from src.backbones.utae import ConvLayer
+
+
+class PaPs(nn.Layer):
+ """
+ Implementation of the Parcel-as-Points Module (PaPs) for panoptic segmentation of agricultural
+ parcels from satellite image time series.
+ Args:
+ encoder (nn.Layer): Backbone encoding network. The encoder is expected to return
+ a feature map at the same resolution as the input images and a list of feature maps
+ of lower resolution.
+ num_classes (int): Number of classes (including stuff and void classes).
+ shape_size (int): S hyperparameter defining the shape of the local patch.
+ mask_conv (bool): If False no residual CNN is applied after combination of
+ the predicted shape and the cropped saliency (default True)
+ min_confidence (float): Cut-off confidence level for the pseudo NMS (predicted instances with
+ lower condidence will not be included in the panoptic prediction).
+ min_remain (float): Hyperparameter of the pseudo-NMS that defines the fraction of a candidate instance mask
+ that needs to be new to be included in the final panoptic prediction (default 0.5).
+ mask_threshold (float): Binary threshold for instance masks (default 0.4)
+ """
+
+ def __init__(
+ self,
+ encoder,
+ num_classes=20,
+ shape_size=16,
+ mask_conv=True,
+ min_confidence=0.2,
+ min_remain=0.5,
+ mask_threshold=0.4,
+ ):
+
+ super(PaPs, self).__init__()
+ self.encoder = encoder
+ self.shape_size = shape_size
+ self.num_classes = num_classes
+ self.min_scale = 1 / shape_size
+ self.register_buffer("min_confidence", paddle.to_tensor([min_confidence]))
+ self.min_remain = min_remain
+ self.mask_threshold = mask_threshold
+ self.center_extractor = CenterExtractor()
+
+ enc_dim = encoder.enc_dim
+ stack_dim = encoder.stack_dim
+ self.heatmap_conv = nn.Sequential(
+ ConvLayer(
+ nkernels=[enc_dim, 32, 1],
+ last_relu=False,
+ k=3,
+ p=1,
+ padding_mode="reflect",
+ ),
+ nn.Sigmoid(),
+ )
+
+ self.saliency_conv = ConvLayer(
+ nkernels=[enc_dim, 32, 1], last_relu=False, k=3, p=1, padding_mode="reflect"
+ )
+
+ self.shape_mlp = nn.Sequential(
+ nn.Linear(stack_dim, stack_dim // 2),
+ nn.BatchNorm1D(stack_dim // 2),
+ nn.ReLU(),
+ nn.Linear(stack_dim // 2, shape_size**2),
+ )
+
+ self.size_mlp = nn.Sequential(
+ nn.Linear(stack_dim, stack_dim // 2),
+ nn.BatchNorm1D(stack_dim // 2),
+ nn.ReLU(),
+ nn.Linear(stack_dim // 2, stack_dim // 4),
+ nn.BatchNorm1D(stack_dim // 4),
+ nn.ReLU(),
+ nn.Linear(stack_dim // 4, 2),
+ nn.Softplus(),
+ )
+
+ self.class_mlp = nn.Sequential(
+ nn.Linear(stack_dim, stack_dim // 2),
+ nn.BatchNorm1D(stack_dim // 2),
+ nn.ReLU(),
+ nn.Linear(stack_dim // 2, stack_dim // 4),
+ nn.Linear(stack_dim // 4, num_classes),
+ )
+
+ if mask_conv:
+ self.mask_cnn = nn.Sequential(
+ nn.Conv2D(in_channels=1, out_channels=16, kernel_size=3, padding=1),
+ nn.GroupNorm(num_channels=16, num_groups=1),
+ nn.ReLU(),
+ nn.Conv2D(in_channels=16, out_channels=16, kernel_size=3, padding=1),
+ nn.ReLU(),
+ nn.Conv2D(in_channels=16, out_channels=1, kernel_size=3, padding=1),
+ )
+ else:
+ self.mask_cnn = None
+
+ def forward(
+ self,
+ input,
+ batch_positions=None,
+ zones=None,
+ pseudo_nms=True,
+ heatmap_only=False,
+ ):
+ """
+ Args:
+ input (tensor): Input image time series.
+ batch_positions (tensor): Date sequence of the batch images.
+ zones (tensor, Optional): Tensor that defines the mapping between each pixel position and
+ the "closest" center during training (see paper paragraph Centerpoint detection). This mapping
+ is used at train time to predict and supervise at most one prediction
+ per ground truth object for efficiency.
+ When not provided all predicted centers receive supervision.
+ pseudo_nms (bool): If True performs pseudo_nms to produce a panoptic prediction,
+ otherwise the model returns potentially overlapping instance segmentation masks (default True).
+ heatmap_only (bool): If True the model only returns the centerness heatmap. Can be useful for some
+ warmup epochs of the centerness prediction, as all the rest hinges on this.
+
+ Returns:
+ predictions (dict[tensor]): A dictionary of predictions with the following keys:
+ center_mask (B,H,W) Binary mask of centers.
+ saliency (B,1,H,W) Global Saliency.
+ heatmap (B,1,H,W) Predicted centerness heatmap.
+ semantic (M, K) Predicted class scores for each center (with M the number of predicted centers).
+ size (M, 2) Predicted sizes for each center.
+ confidence (M,1) Predicted centerness for each center.
+ centerness (M,1) Predicted centerness for each center.
+ instance_masks List of N binary masks of varying shape.
+ instance_boxes (N, 4) Coordinates of the N bounding boxes.
+ pano_instance (B,H,W) Predicted instance id for each pixel.
+ pano_semantic (B,K,H,W) Predicted class score for each pixel.
+
+ """
+ out, maps = self.encoder(input, batch_positions=batch_positions)
+
+ # Global Predictions
+ heatmap = self.heatmap_conv(out)
+ saliency = self.saliency_conv(out)
+
+ center_mask, _ = self.center_extractor(
+ heatmap, zones=zones
+ ) # (B,H,W) mask of N detected centers
+ # Don't squeeze batch dimension to maintain consistency with loss function expectations
+
+ if heatmap_only:
+ predictions = dict(
+ center_mask=center_mask,
+ saliency=None,
+ heatmap=heatmap,
+ semantic=None,
+ size=None,
+ offsets=None,
+ confidence=None,
+ instance_masks=None,
+ instance_boxes=None,
+ pano_instance=None,
+ pano_semantic=None,
+ )
+ return predictions
+
+ # Retrieve info of detected centers
+ H, W = heatmap.shape[-2:]
+ # center_mask is now always 3D (B, H, W)
+ center_indices = paddle.nonzero(center_mask, as_tuple=False)
+
+ if center_indices.shape[0] == 0:
+ # Handle case where no centers detected
+ center_batch = paddle.empty([0], dtype="int64")
+ center_h = paddle.empty([0], dtype="int64")
+ center_w = paddle.empty([0], dtype="int64")
+ center_positions = paddle.empty([0, 2], dtype="int64")
+ else:
+ # center_mask is (B, H, W), so indices are (N, 3)
+ center_batch = center_indices[:, 0]
+ center_h = center_indices[:, 1]
+ center_w = center_indices[:, 2]
+ center_positions = paddle.stack([center_h, center_w], axis=1)
+
+ # Construct multi-level feature stack for centers
+ stack = []
+ for i, m in enumerate(maps):
+ h_mask = center_h // (2 ** (len(maps) - 1 - i))
+ # Assumes resolution is divided by 2 at each level
+ w_mask = center_w // (2 ** (len(maps) - 1 - i))
+ m = m.transpose([0, 2, 3, 1])
+ # Use paddle.gather_nd for advanced indexing
+ indices = paddle.stack([center_batch, h_mask, w_mask], axis=1)
+ stack.append(paddle.gather_nd(m, indices))
+ stack = paddle.concat(stack, axis=1)
+
+ # Center-level predictions
+ size = self.size_mlp(stack)
+ sem = self.class_mlp(stack)
+ shapes = self.shape_mlp(stack)
+ shapes = shapes.reshape([-1, 1, self.shape_size, self.shape_size])
+ # (N,1,S,S) instance shapes
+
+ # Extract centerness from heatmap at center positions
+ # Use gather_nd to extract values at specific positions
+ if center_h.shape[0] > 0:
+ # Create indices for gather_nd: [batch_idx, channel_idx, h_idx, w_idx]
+ batch_indices = center_batch.unsqueeze(1) # [N, 1]
+ channel_indices = paddle.zeros_like(batch_indices) # [N, 1] - channel 0
+ h_indices = center_h.unsqueeze(1) # [N, 1]
+ w_indices = center_w.unsqueeze(1) # [N, 1]
+ gather_indices = paddle.concat(
+ [batch_indices, channel_indices, h_indices, w_indices], axis=1
+ )
+ centerness = paddle.gather_nd(heatmap, gather_indices).unsqueeze(-1)
+ else:
+ centerness = paddle.empty([0, 1])
+ confidence = centerness
+
+ # Instance Boxes Assembling
+ ## Minimal box size of 1px
+ ## Combine clamped sizes and center positions to obtain box coordinates
+ clamp_size = size.detach().round().astype("int64").clip(min=1)
+ half_size = clamp_size // 2
+ remainder_size = clamp_size % 2
+ start_hw = center_positions - half_size
+ stop_hw = center_positions + half_size + remainder_size
+
+ instance_boxes = paddle.concat([start_hw, stop_hw], axis=1)
+ instance_boxes = paddle.clip(instance_boxes, min=0, max=H)
+ instance_boxes = instance_boxes[:, [1, 0, 3, 2]] # h,w,h,w to x,y,x,y
+
+ valid_start = paddle.clip(
+ -start_hw, min=0
+ ) # if h=-5 crop the shape mask before the 5th pixel
+ valid_stop = (stop_hw - start_hw) - paddle.clip(
+ stop_hw - H, min=0
+ ) # crop if h_stop > H
+
+ # Instance Masks Assembling
+ instance_masks = []
+ # Manual splitting to match PyTorch behavior exactly
+ # PyTorch: shapes.split(1, dim=0) gives list of [1, 1, S, S] tensors
+ # PaddlePaddle: paddle.split() behaves differently, use manual approach
+ for i in range(shapes.shape[0]):
+ s = shapes[i : i + 1] # [1, 1, S, S] - exactly like PyTorch split
+ h, w = clamp_size[i] # Box size
+ w_start, h_start, w_stop, h_stop = instance_boxes[
+ i
+ ] # Box coordinates (x,y,x,y format)
+ h_start_valid, w_start_valid = valid_start[i] # Part of the Box that lies
+ h_stop_valid, w_stop_valid = valid_stop[i] # within the image's extent
+
+ ## Resample local shape mask - match PyTorch exactly
+ # s is single shape [1, 1, shape_size, shape_size] from split
+ pred_mask = F.interpolate(s, size=[h.item(), w.item()], mode="bilinear")
+ pred_mask = pred_mask.squeeze(0) # Remove batch dim -> [1, h, w]
+ pred_mask = pred_mask[
+ :, h_start_valid:h_stop_valid, w_start_valid:w_stop_valid
+ ]
+
+ ## Crop saliency
+ batch_idx = int(center_batch[i]) # Ensure scalar index
+ crop_saliency = saliency[batch_idx, :, h_start:h_stop, w_start:w_stop]
+
+ ## Combine both
+ if self.mask_cnn is None:
+ pred_mask = F.sigmoid(pred_mask) * F.sigmoid(crop_saliency)
+ # Debug: print shape for mask_cnn is None case (only if needed)
+ # print(f"Debug - pred_mask shape (no mask_cnn): {pred_mask.shape}")
+ else:
+ pred_mask = pred_mask + crop_saliency
+ # Ensure pred_mask is [C, H, W] before mask_cnn
+ if pred_mask.ndim != 3:
+ raise ValueError(
+ f"pred_mask should be 3D [C,H,W], got shape {pred_mask.shape}"
+ )
+ pred_mask = F.sigmoid(pred_mask) * F.sigmoid(
+ self.mask_cnn(pred_mask.unsqueeze(0)).squeeze(0)
+ )
+
+ # Debug: print shape when appending to instance_masks (only if needed)
+ # print(f"Debug - appending pred_mask with shape: {pred_mask.shape}")
+ instance_masks.append(pred_mask)
+
+ # PSEUDO-NMS
+ if pseudo_nms:
+ panoptic_instance = []
+ panoptic_semantic = []
+ for b in range(center_mask.shape[0]): # iterate over elements of batch
+ panoptic_mask = paddle.zeros(center_mask[0].shape, dtype="float32")
+ semantic_mask = paddle.zeros(
+ [self.num_classes] + list(center_mask[0].shape), dtype="float32"
+ )
+
+ # Get indices of centers in this batch element - match PyTorch exactly
+ candidates = paddle.nonzero(center_batch == b).squeeze(-1)
+ if candidates.ndim == 0: # Handle single candidate case
+ candidates = candidates.unsqueeze(0)
+
+ if len(candidates) > 0:
+ # Sort by confidence descending - match PyTorch logic exactly
+ candidate_confidences = confidence[candidates].squeeze(-1)
+ if candidate_confidences.ndim == 0: # Handle single confidence case
+ candidate_confidences = candidate_confidences.unsqueeze(0)
+
+ # Use argsort to get indices, then get sorted values - match torch.sort behavior
+ sorted_indices = paddle.argsort(
+ candidate_confidences, descending=True
+ )
+ sorted_values = candidate_confidences[sorted_indices]
+
+ for n, (c, idx_in_candidates) in enumerate(
+ zip(sorted_values, sorted_indices)
+ ):
+ if c < self.min_confidence:
+ break
+ else:
+ # Get the actual index in the original candidates array
+ actual_idx = candidates[idx_in_candidates]
+
+ new_mask = paddle.zeros(
+ center_mask[0].shape, dtype="float32"
+ )
+ # Match PyTorch exactly: instance_masks[candidates[idx]].squeeze(0)
+ instance_mask = instance_masks[actual_idx]
+
+ # Robust squeeze to handle any extra dimensions - match PyTorch .squeeze(0)
+ while (
+ instance_mask.ndim > 2 and instance_mask.shape[0] == 1
+ ):
+ instance_mask = instance_mask.squeeze(0)
+
+ pred_mask_bin = (
+ instance_mask > self.mask_threshold
+ ).astype("float32")
+
+ # Get box coordinates first, before checking if mask is valid
+ xtl, ytl, xbr, ybr = instance_boxes[actual_idx]
+
+ if pred_mask_bin.sum() > 0:
+ # Simple assignment like PyTorch - should work now with correct shapes
+ new_mask[ytl:ybr, xtl:xbr] = pred_mask_bin
+
+ # Check for overlap - match PyTorch logic exactly
+ if ((new_mask != 0) * (panoptic_mask != 0)).any():
+ n_total = (new_mask != 0).sum()
+ non_overlaping_mask = (new_mask != 0) * (
+ panoptic_mask == 0
+ )
+ n_new = non_overlaping_mask.sum().astype("float32")
+ if n_new / n_total > self.min_remain:
+ # Direct assignment like PyTorch - fix data flow issue
+ panoptic_mask = paddle.where(
+ non_overlaping_mask,
+ paddle.full_like(panoptic_mask, n + 1),
+ panoptic_mask,
+ )
+
+ # Semantic assignment - match PyTorch exactly using advanced indexing
+ sem_values = sem[actual_idx] # [num_classes]
+ # PyTorch: semantic_mask[:, non_overlaping_mask] = sem[candidates[idx]][:, None]
+ # Find positions where mask is True
+ mask_positions = paddle.nonzero(
+ non_overlaping_mask
+ ) # [N, 2]
+ if len(mask_positions) > 0:
+ # Extract coordinates
+ h_coords = mask_positions[:, 0] # [N]
+ w_coords = mask_positions[:, 1] # [N]
+ # Assign semantic values to all mask positions
+ for i in range(self.num_classes):
+ semantic_mask[
+ i, h_coords, w_coords
+ ] = sem_values[i]
+ else:
+ # No overlap case - direct assignment
+ new_mask_bool = new_mask != 0
+ panoptic_mask = paddle.where(
+ new_mask_bool,
+ paddle.full_like(panoptic_mask, n + 1),
+ panoptic_mask,
+ )
+
+ # Semantic assignment - match PyTorch exactly using advanced indexing
+ sem_values = sem[actual_idx] # [num_classes]
+ # PyTorch: semantic_mask[:, (new_mask != 0)] = sem[candidates[idx]][:, None]
+ # Find positions where mask is True
+ mask_positions = paddle.nonzero(
+ new_mask_bool
+ ) # [N, 2]
+ if len(mask_positions) > 0:
+ # Extract coordinates
+ h_coords = mask_positions[:, 0] # [N]
+ w_coords = mask_positions[:, 1] # [N]
+ # Assign semantic values to all mask positions
+ for i in range(self.num_classes):
+ semantic_mask[
+ i, h_coords, w_coords
+ ] = sem_values[i]
+
+ panoptic_instance.append(panoptic_mask)
+ panoptic_semantic.append(semantic_mask)
+ panoptic_instance = paddle.stack(panoptic_instance, axis=0)
+ panoptic_semantic = paddle.stack(panoptic_semantic, axis=0)
+ else:
+ panoptic_instance = None
+ panoptic_semantic = None
+
+ predictions = dict(
+ center_mask=center_mask,
+ saliency=saliency,
+ heatmap=heatmap,
+ semantic=sem,
+ size=size,
+ confidence=confidence,
+ centerness=centerness,
+ instance_masks=instance_masks,
+ instance_boxes=instance_boxes,
+ pano_instance=panoptic_instance,
+ pano_semantic=panoptic_semantic,
+ )
+
+ return predictions
+
+
+class CenterExtractor(nn.Layer):
+ def __init__(self):
+ """
+ Module for local maxima extraction
+ """
+ super(CenterExtractor, self).__init__()
+ self.pool = nn.MaxPool2D(kernel_size=3, stride=1, padding=1)
+
+ def forward(self, input, zones=None):
+ """
+ Args:
+ input (tensor): Centerness heatmap
+ zones (tensor, Optional): Tensor that defines the mapping between each pixel position and
+ the "closest" center during training (see paper paragraph Centerpoint detection).
+ If provided, the highest local maxima in each zone is kept. As a result at most one
+ prediction is made per ground truth object.
+ If not provided, all local maxima are returned.
+ """
+ if zones is not None:
+ # Note: torch_scatter functionality needs to be implemented using native Paddle operations
+ # This is a simplified implementation - may need refinement for exact equivalence
+ masks = []
+ for b in range(input.shape[0]):
+ x = input[b].flatten()
+ zones_flat = zones[b].flatten().astype("int64")
+
+ # Group by zone indices and find max in each zone
+ # This is a simplified approach - actual scatter_max would be more efficient
+ unique_zones = paddle.unique(zones_flat)
+ mask = paddle.zeros_like(x)
+
+ for zone in unique_zones:
+ if zone >= 0: # Skip invalid zones
+ zone_mask = zones_flat == zone
+ zone_values = x[zone_mask]
+ if len(zone_values) > 0:
+ max_idx_in_zone = paddle.argmax(zone_values)
+ global_indices = paddle.nonzero(zone_mask).squeeze()
+ if global_indices.ndim == 0:
+ global_indices = global_indices.unsqueeze(0)
+ max_global_idx = global_indices[max_idx_in_zone]
+ mask[max_global_idx] = 1
+
+ # Ensure zones[b] is 2D - remove last dimension if it's 1
+ zone_shape = zones[b].shape
+ # print(f"Debug: original zone_shape: {zone_shape}")
+ if len(zone_shape) == 3 and zone_shape[-1] == 1:
+ zone_shape = zone_shape[:-1] # (H, W, 1) -> (H, W)
+ # print(f"Debug: adjusted zone_shape: {zone_shape}")
+ reshaped_mask = mask.reshape(zone_shape)
+ # print(f"Debug: reshaped_mask shape: {reshaped_mask.shape}")
+ final_mask = reshaped_mask.unsqueeze(0)
+ # print(f"Debug: final_mask shape after unsqueeze: {final_mask.shape}")
+ masks.append(final_mask)
+ centermask = paddle.stack(masks, axis=0).astype("bool")
+ # print(f"Debug: centermask shape after stack: {centermask.shape}")
+ # Ensure centermask is (B, H, W) - remove any singleton dimensions except batch
+ while len(centermask.shape) > 3:
+ if centermask.shape[1] == 1:
+ centermask = centermask.squeeze(1)
+ # print(f"Debug: centermask shape after squeeze(1): {centermask.shape}")
+ else:
+ break
+ else:
+ centermask = input == self.pool(input)
+ no_valley = input > input.mean()
+ centermask = centermask * no_valley
+ # Ensure centermask is (B, H, W) by squeezing channel dimension if it's 1
+ if centermask.shape[1] == 1:
+ centermask = centermask.squeeze(1)
+
+ n_centers = int(centermask.sum().detach().cpu().item())
+ return centermask, n_centers
diff --git a/examples/UTAE/src/panoptic/paps_loss.py b/examples/UTAE/src/panoptic/paps_loss.py
new file mode 100644
index 000000000..68c60f685
--- /dev/null
+++ b/examples/UTAE/src/panoptic/paps_loss.py
@@ -0,0 +1,266 @@
+"""
+PaPs Implementation (Paddle Version)
+Converted to PaddlePaddle
+"""
+
+import paddle
+import paddle.nn as nn
+from src.panoptic.FocalLoss import FocalLoss
+
+
+class PaPsLoss(nn.Layer):
+ """
+ Loss for training PaPs.
+ Args:
+ l_center (float): Coefficient for the centerness loss (default 1)
+ l_size (float): Coefficient for the size loss (default 1)
+ l_shape (float): Coefficient for the shape loss (default 1)
+ l_class (float): Coefficient for the classification loss (default 1)
+ alpha (float): Parameter for the centerness loss (default 0)
+ beta (float): Parameter for the centerness loss (default 4)
+ gamma (float): Focal exponent for the classification loss (default 0)
+ eps (float): Stability epsilon
+ void_label (int): Label to ignore in the classification loss
+ """
+
+ def __init__(
+ self,
+ l_center=1,
+ l_size=1,
+ l_shape=1,
+ l_class=1,
+ alpha=0,
+ beta=4,
+ gamma=0,
+ eps=1e-8,
+ void_label=None,
+ binary_threshold=0.4,
+ ):
+
+ super(PaPsLoss, self).__init__()
+ self.l_center = l_center
+ self.l_size = l_size
+ self.l_shape = l_shape
+ self.l_class = l_class
+ self.eps = eps
+ self.binary_threshold = binary_threshold
+
+ self.center_loss = CenterLoss(alpha=alpha, beta=beta, eps=eps)
+ self.class_loss = FocalLoss(gamma=gamma, ignore_label=void_label)
+ self.shape_loss = FocalLoss(gamma=0)
+ self.value = (0, 0, 0, 0, 0)
+
+ # Keep track of the predicted confidences and ious between predicted and gt binary masks.
+ # This is usefull for tuning the confidence threshold of the pseudo-nms.
+ self.predicted_confidences = None
+ self.achieved_ious = None
+
+ def forward(self, predictions, target, heatmap_only=False):
+ # Split target tensor - equivalent to torch.split
+ target_splits = paddle.split(target, [1, 1, 1, 2, 1, 1], axis=-1)
+ (
+ target_heatmap,
+ true_instances,
+ zones,
+ size,
+ sem_obj,
+ sem_pix,
+ ) = target_splits
+
+ # Create center mapping dictionary - following original torch logic
+ # center_mask is now always 3D (B, H, W) from PaPs model
+ center_mask = predictions["center_mask"]
+ center_indices = paddle.nonzero(center_mask)
+ center_mapping = {}
+
+ if center_indices.shape[0] > 0:
+ # Create mapping: (batch_idx, height_idx, width_idx) -> center_id
+ for k, (b, i, j) in enumerate(
+ zip(center_indices[:, 0], center_indices[:, 1], center_indices[:, 2])
+ ):
+ center_mapping[(int(b), int(i), int(j))] = k
+
+ loss_center = 0
+ loss_size = 0
+ loss_shape = 0
+ loss_class = 0
+
+ if self.l_center != 0:
+ loss_center = self.center_loss(predictions["heatmap"], target_heatmap)
+
+ if not heatmap_only and predictions["size"].shape[0] != 0:
+ if self.l_size != 0:
+ # Use center indices to extract corresponding sizes
+ if center_indices.shape[0] > 0:
+ # Extract true sizes at center locations
+ # center_indices now has shape (N, 3) with (batch, height, width)
+ batch_ids = center_indices[:, 0]
+ h_ids = center_indices[:, 1]
+ w_ids = center_indices[:, 2]
+
+ # Use gather_nd to extract sizes
+ size_indices = paddle.stack([batch_ids, h_ids, w_ids], axis=1)
+ true_size = paddle.gather_nd(size, size_indices) # (N, 2)
+
+ loss_size = paddle.abs(true_size - predictions["size"]) / (
+ true_size + self.eps
+ )
+ loss_size = loss_size.sum(axis=-1).mean()
+ else:
+ loss_size = paddle.to_tensor(0.0)
+
+ if self.l_class != 0:
+ # Use center indices for semantic object labels
+ if center_indices.shape[0] > 0:
+ # Extract semantic labels at center locations
+ # center_indices now has shape (N, 3) with (batch, height, width)
+ batch_ids = center_indices[:, 0]
+ h_ids = center_indices[:, 1]
+ w_ids = center_indices[:, 2]
+
+ # Use gather_nd to extract semantic labels
+ sem_indices = paddle.stack([batch_ids, h_ids, w_ids], axis=1)
+ sem_labels = paddle.gather_nd(sem_obj, sem_indices) # (N, 1)
+ sem_labels = sem_labels.squeeze(
+ -1
+ ) # Remove last dimension to get (N,)
+
+ loss_class = self.class_loss(
+ predictions["semantic"],
+ sem_labels.astype("int64"),
+ )
+ else:
+ loss_class = paddle.to_tensor(0.0)
+
+ if self.l_shape != 0:
+ confidence_pred = []
+ ious = []
+ flatten_preds = []
+ flatten_target = []
+
+ # Faithful to original PyTorch implementation
+ for b in range(true_instances.shape[0]):
+ instance_mask = true_instances[b]
+ for inst_id in paddle.unique(instance_mask):
+ centers = predictions["center_mask"][b] * (
+ zones[b] == inst_id
+ ).squeeze(
+ -1
+ ) # center matching
+
+ if not centers.any():
+ continue
+
+ # Original PyTorch style: iterate over nonzero positions
+ center_positions = paddle.nonzero(centers)
+ for pos in center_positions:
+ x, y = int(pos[0]), int(pos[1])
+ true_mask = (
+ (instance_mask == inst_id).squeeze(-1).astype("float32")
+ )
+
+ pred_id = center_mapping[(b, int(x), int(y))]
+
+ xtl, ytl, xbr, ybr = predictions["instance_boxes"][pred_id]
+
+ crop_true = true_mask[ytl:ybr, xtl:xbr].reshape([-1, 1])
+ mask = predictions["instance_masks"][pred_id].reshape(
+ [-1, 1]
+ )
+
+ flatten_preds.append(mask)
+ flatten_target.append(crop_true)
+
+ confidence_pred.append(predictions["confidence"][pred_id])
+ bmask = (mask > self.binary_threshold).astype("float32")
+ inter = (bmask * crop_true).sum().astype("float32")
+ union = ((bmask + crop_true) != 0).astype("float32").sum()
+ true_mask[ytl:ybr, xtl:xbr] = 0
+ union = (
+ union + true_mask.sum()
+ ) # parts of shape outside of bbox
+ iou = inter / union
+ if paddle.isnan(iou) or paddle.isinf(iou):
+ iou = paddle.zeros([1], dtype="float32")
+ ious.append(iou)
+
+ if len(flatten_preds) > 0:
+ p = paddle.concat(flatten_preds, axis=0)
+ p = paddle.concat([1 - p, p], axis=1)
+ t = paddle.concat(flatten_target, axis=0).astype("int64")
+ loss_shape = self.shape_loss(p, t)
+
+ self.predicted_confidences = paddle.stack(confidence_pred)
+ self.achieved_ious = paddle.stack(ious).unsqueeze(-1)
+ else:
+ loss_shape = paddle.to_tensor(0.0)
+
+ loss = (
+ self.l_center * loss_center
+ + self.l_size * loss_size
+ + self.l_shape * loss_shape
+ + self.l_class * loss_class
+ )
+
+ self.value = (
+ float(loss_center.detach().cpu().item())
+ if isinstance(loss_center, paddle.Tensor)
+ else loss_center,
+ float(loss_size.detach().cpu().item())
+ if isinstance(loss_size, paddle.Tensor)
+ else loss_size,
+ float(loss_shape.detach().cpu().item())
+ if isinstance(loss_shape, paddle.Tensor)
+ else loss_shape,
+ float(loss_class.detach().cpu().item())
+ if isinstance(loss_class, paddle.Tensor)
+ else loss_class,
+ )
+ return loss
+
+
+class CenterLoss(nn.Layer):
+ """
+ Adapted from the github repo of the CornerNet paper
+ https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py
+ Converted to PaddlePaddle
+ """
+
+ def __init__(self, alpha=0, beta=4, eps=1e-8):
+ super(CenterLoss, self).__init__()
+ self.a = alpha
+ self.b = beta
+ self.eps = eps
+
+ def forward(self, preds, gt):
+ pred = preds.transpose([0, 2, 3, 1]).reshape([-1, preds.shape[1]])
+ g = gt.reshape([-1, preds.shape[1]])
+
+ pos_inds = g == 1
+ neg_inds = g < 1
+ num_pos = pos_inds.astype("float32").sum()
+ loss = 0
+
+ if pos_inds.any():
+ pos_pred = pred[pos_inds]
+ pos_loss = paddle.log(pos_pred + self.eps)
+ pos_loss = pos_loss * paddle.pow(1 - pos_pred, self.a)
+ pos_loss = pos_loss.sum()
+ else:
+ pos_loss = paddle.to_tensor(0.0)
+
+ if neg_inds.any():
+ neg_pred = pred[neg_inds]
+ neg_g = g[neg_inds]
+ neg_loss = paddle.log(1 - neg_pred + self.eps)
+ neg_loss = neg_loss * paddle.pow(neg_pred, self.a)
+ neg_loss = neg_loss * paddle.pow(1 - neg_g, self.b)
+ neg_loss = neg_loss.sum()
+ else:
+ neg_loss = paddle.to_tensor(0.0)
+
+ if not pos_inds.any():
+ loss = loss - neg_loss
+ else:
+ loss = loss - (pos_loss + neg_loss) / num_pos
+ return loss
diff --git a/examples/UTAE/src/utils.py b/examples/UTAE/src/utils.py
new file mode 100644
index 000000000..66ec985fb
--- /dev/null
+++ b/examples/UTAE/src/utils.py
@@ -0,0 +1,152 @@
+"""
+Utility functions (Paddle Version)
+"""
+import collections.abc
+import re
+
+import numpy as np
+import paddle
+
+np_str_obj_array_pattern = re.compile(r"[SaUO]")
+
+"""Pad tensor to target shape for all dimensions"""
+
+
+def pad_tensor(x, target_shape, pad_value=0):
+
+ if len(x.shape) != len(target_shape):
+ raise ValueError(f"Shape mismatch: {x.shape} vs {target_shape}")
+
+ # Check if padding is needed
+ if tuple(x.shape) == tuple(target_shape):
+ return x
+
+ # Calculate padding for each dimension
+ # Paddle padding format: [dim_n_left, dim_n_right, dim_{n-1}_left, dim_{n-1}_right, ...]
+ # For 2D: [dim1_left, dim1_right, dim0_left, dim0_right]
+ # For 4D: [dim3_left, dim3_right, dim2_left, dim2_right, dim1_left, dim1_right, dim0_left, dim0_right]
+
+ pad = []
+ needs_padding = False
+
+ # Build padding list from last dimension to first
+ # But we need to add pairs in the correct order for Paddle
+ pad_pairs = []
+ for i in range(len(x.shape) - 1, -1, -1):
+ pad_size = target_shape[i] - x.shape[i]
+ if pad_size < 0:
+ raise ValueError(
+ f"Target size {target_shape[i]} smaller than current size {x.shape[i]} in dim {i}"
+ )
+
+ if pad_size > 0:
+ needs_padding = True
+
+ # Store [left_pad, right_pad] for this dimension
+ pad_pairs.append([0, pad_size])
+
+ # Reverse the pairs to match Paddle's expected order
+ pad_pairs.reverse()
+
+ # Flatten the pairs into the final pad list
+ for pair in pad_pairs:
+ pad.extend(pair)
+
+ if not needs_padding:
+ return x
+
+ # Apply padding
+ result = paddle.nn.functional.pad(x, pad=pad, value=pad_value)
+
+ # Verify result shape
+ if tuple(result.shape) != tuple(target_shape):
+ # Print debug info for troubleshooting
+ print(
+ f"Debug: input_shape={x.shape}, target_shape={target_shape}, pad={pad}, result_shape={result.shape}"
+ )
+ raise ValueError(f"Padding failed: expected {target_shape}, got {result.shape}")
+
+ return result
+
+
+"""Get the maximum shape across all tensors in batch"""
+
+
+def get_max_shape(batch):
+
+ if not batch:
+ return None
+
+ max_shape = list(batch[0].shape)
+ for tensor in batch[1:]:
+ for i, size in enumerate(tensor.shape):
+ max_shape[i] = max(max_shape[i], size)
+
+ return tuple(max_shape)
+
+
+"""
+Modified default_collate from the official pytorch repo for padding variable length sequences
+Adapted from the original PyTorch implementation
+"""
+
+
+def pad_collate(batch, pad_value=0):
+
+ elem = batch[0]
+ elem_type = type(elem)
+
+ if isinstance(elem, paddle.Tensor):
+ _out = None
+ if len(elem.shape) > 0:
+ # Check if any shapes differ
+ shapes = [e.shape for e in batch]
+ if not all(s == shapes[0] for s in shapes):
+ # Get maximum shape across all dimensions
+ max_shape = get_max_shape(batch)
+ # Pad all tensors to max shape
+ batch = [pad_tensor(e, max_shape, pad_value=pad_value) for e in batch]
+ return paddle.stack(batch, axis=0)
+
+ elif (
+ elem_type.__module__ == "numpy"
+ and elem_type.__name__ != "str_"
+ and elem_type.__name__ != "string_"
+ ):
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError("Format not managed : {}".format(elem.dtype))
+ return pad_collate([paddle.to_tensor(b) for b in batch], pad_value)
+ elif elem.shape == (): # scalars
+ return paddle.to_tensor(batch)
+
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: pad_collate([d[key] for d in batch], pad_value) for key in elem}
+
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
+ return elem_type(
+ *(pad_collate(list(samples), pad_value) for samples in zip(*batch))
+ )
+
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError("each element in list of batch should be of equal size")
+ transposed = list(zip(*batch))
+ return [pad_collate(samples, pad_value) for samples in transposed]
+
+ raise TypeError("Format not managed : {}".format(elem_type))
+
+
+"""
+Set random seed for reproducibility
+"""
+
+
+def set_seed(seed):
+
+ np.random.seed(seed)
+ paddle.seed(seed)
diff --git a/examples/UTAE/test_panoptic.py b/examples/UTAE/test_panoptic.py
new file mode 100644
index 000000000..5ad583d2e
--- /dev/null
+++ b/examples/UTAE/test_panoptic.py
@@ -0,0 +1,367 @@
+"""
+Script for panoptic inference with pre-trained models (Paddle Version)
+Converted to PaddlePaddle
+"""
+import argparse
+import json
+import os
+import pprint
+
+import numpy as np
+import paddle
+from src import model_utils as model_utils
+from src.dataset import PASTIS_Dataset
+from src.model_utils import get_ntrainparams
+from src.panoptic.metrics import PanopticMeter
+from src.panoptic.paps_loss import PaPsLoss
+from src.utils import pad_collate
+
+parser = argparse.ArgumentParser()
+# Model parameters
+parser.add_argument(
+ "--weight_folder",
+ type=str,
+ default="",
+ help="Path to the main folder containing the pre-trained weights",
+)
+parser.add_argument(
+ "--dataset_folder",
+ default="",
+ type=str,
+ help="Path to the folder where the results are saved.",
+)
+parser.add_argument(
+ "--res_dir",
+ default="./inference_paps",
+ type=str,
+ help="Path to directory where results are written.",
+)
+parser.add_argument(
+ "--num_workers", default=4, type=int, help="Number of data loading workers"
+)
+parser.add_argument(
+ "--fold",
+ default=None,
+ type=int,
+ help="Do only one of the five fold (between 1 and 5)",
+)
+parser.add_argument(
+ "--device",
+ default="gpu",
+ type=str,
+ help="Name of device to use for tensor computations (gpu/cpu)",
+)
+parser.add_argument(
+ "--display_step",
+ default=50,
+ type=int,
+ help="Interval in batches between display of training metrics",
+)
+parser.add_argument("--batch_size", default=2, type=int, help="Batch size")
+
+
+def recursive_todevice(x, device):
+ if isinstance(x, paddle.Tensor):
+ return x.cuda() if device == "gpu" else x.cpu()
+ else:
+ return [recursive_todevice(c, device) for c in x]
+
+
+def prepare_output(config):
+ os.makedirs(config.res_dir, exist_ok=True)
+ for fold in range(1, 6):
+ os.makedirs(os.path.join(config.res_dir, "Fold_{}".format(fold)), exist_ok=True)
+
+
+def iterate(
+ model,
+ data_loader,
+ criterion,
+ panoptic_meter,
+ config,
+ optimizer=None,
+ mode="test",
+ device="gpu",
+):
+ """Inference iteration for panoptic segmentation"""
+ loss_meter = 0
+
+ for i, batch in enumerate(data_loader):
+ if device == "gpu":
+ batch = recursive_todevice(batch, device)
+
+ (x, dates), targets = batch
+ targets = targets.astype("float32")
+
+ with paddle.no_grad():
+ # Full panoptic prediction with pseudo-NMS
+ predictions = model(
+ x,
+ batch_positions=dates,
+ zones=targets[:, :, :, 2:3] if config.supmax else None,
+ heatmap_only=False,
+ pseudo_nms=True,
+ )
+
+ # Compute loss (optional for testing)
+ loss = criterion(predictions, targets, heatmap_only=False)
+
+ # Update metrics
+ loss_meter += loss.item()
+
+ # Add predictions to panoptic meter
+ if predictions["pano_semantic"] is not None:
+ panoptic_meter.add(predictions, targets)
+
+ if (i + 1) % config.display_step == 0:
+ SQ, RQ, PQ = panoptic_meter.value()
+ print(
+ f"{mode} - Step [{i+1}/{len(data_loader)}] Loss: {loss_meter/(i+1):.4f} "
+ f"SQ: {SQ*100:.1f} RQ: {RQ*100:.1f} PQ: {PQ*100:.1f}"
+ )
+
+ # Final metrics
+ SQ, RQ, PQ = panoptic_meter.value()
+
+ metrics = {
+ f"{mode}_loss": loss_meter / len(data_loader),
+ f"{mode}_SQ": float(SQ),
+ f"{mode}_RQ": float(RQ),
+ f"{mode}_PQ": float(PQ),
+ }
+
+ return metrics, panoptic_meter.get_table()
+
+
+def save_results(fold, metrics, tables, config):
+ """Save test results"""
+ fold_dir = os.path.join(config.res_dir, f"Fold_{fold}")
+
+ # Save metrics as JSON
+ with open(os.path.join(fold_dir, "test_metrics.json"), "w") as f:
+ json.dump(metrics, f, indent=4)
+
+ # Save detailed tables as numpy binary file (same as PyTorch version)
+ np.save(os.path.join(fold_dir, "test_tables"), tables)
+
+
+def overall_performance(config):
+ """Compute overall performance across all folds"""
+ all_metrics = []
+ all_tables = []
+
+ for fold in range(1, 6):
+ fold_dir = os.path.join(config.res_dir, f"Fold_{fold}")
+
+ # Load metrics
+ metrics_path = os.path.join(fold_dir, "test_metrics.json")
+ if not os.path.exists(metrics_path):
+ continue
+
+ with open(metrics_path, "r") as f:
+ metrics = json.load(f)
+ all_metrics.append(metrics)
+
+ # Load tables (numpy format)
+ tables_path = os.path.join(fold_dir, "test_tables.npy")
+ if os.path.exists(tables_path):
+ tables = np.load(tables_path)
+ all_tables.append(tables)
+
+ if not all_metrics:
+ print("No test results found!")
+ return
+
+ # Compute averages
+ avg_metrics = {}
+ for key in all_metrics[0].keys():
+ if isinstance(all_metrics[0][key], (int, float)):
+ avg_metrics[key] = np.mean([m[key] for m in all_metrics])
+ avg_metrics[key + "_std"] = np.std([m[key] for m in all_metrics])
+
+ # Save overall results
+ with open(os.path.join(config.res_dir, "overall_metrics.json"), "w") as f:
+ json.dump(avg_metrics, f, indent=4)
+
+ print("=== OVERALL PANOPTIC RESULTS ===")
+ print(
+ f"Average Loss: {avg_metrics['test_loss']:.4f} ± {avg_metrics['test_loss_std']:.4f}"
+ )
+ print(
+ f"Average SQ: {avg_metrics['test_SQ']*100:.1f} ± {avg_metrics['test_SQ_std']*100:.1f}"
+ )
+ print(
+ f"Average RQ: {avg_metrics['test_RQ']*100:.1f} ± {avg_metrics['test_RQ_std']*100:.1f}"
+ )
+ print(
+ f"Average PQ: {avg_metrics['test_PQ']*100:.1f} ± {avg_metrics['test_PQ_std']*100:.1f}"
+ )
+
+
+def main(config):
+ fold_sequence = [
+ [[1, 2, 3], [4], [5]],
+ [[2, 3, 4], [5], [1]],
+ [[3, 4, 5], [1], [2]],
+ [[4, 5, 1], [2], [3]],
+ [[5, 1, 2], [3], [4]],
+ ]
+
+ paddle.seed(config.rdm_seed)
+ np.random.seed(config.rdm_seed)
+ prepare_output(config)
+
+ # Set device
+ if config.device == "gpu" and paddle.is_compiled_with_cuda():
+ paddle.device.set_device("gpu")
+ else:
+ paddle.device.set_device("cpu")
+ config.device = "cpu"
+
+ # Create model
+ model = model_utils.get_model(config, mode="panoptic")
+ config.N_params = get_ntrainparams(model)
+ print("TOTAL TRAINABLE PARAMETERS :", config.N_params)
+
+ fold_sequence = (
+ fold_sequence if config.fold is None else [fold_sequence[config.fold - 1]]
+ )
+
+ for fold, (train_folds, val_fold, test_fold) in enumerate(fold_sequence):
+ if config.fold is not None:
+ fold = config.fold - 1
+
+ print(f"\n=== Testing Fold {fold + 1} ===")
+
+ # Dataset definition
+ dt_test = PASTIS_Dataset(
+ folder=config.dataset_folder,
+ norm=True,
+ reference_date=config.ref_date,
+ mono_date=config.mono_date,
+ target="instance", # Important: use instance target for panoptic
+ sats=["S2"],
+ folds=test_fold,
+ )
+ collate_fn = lambda x: pad_collate(x, pad_value=config.pad_value)
+ test_loader = paddle.io.DataLoader(
+ dt_test,
+ batch_size=config.batch_size,
+ shuffle=False,
+ num_workers=config.num_workers,
+ collate_fn=collate_fn,
+ )
+
+ print(f"Test samples: {len(dt_test)}")
+
+ # Load weights
+ weight_path = config.weight_folder
+
+ if not os.path.exists(weight_path):
+ print(f"Warning: Weight file not found at {weight_path}")
+ continue
+
+ checkpoint = paddle.load(weight_path)
+ if "state_dict" in checkpoint:
+ model.set_state_dict(checkpoint["state_dict"])
+ else:
+ model.set_state_dict(checkpoint)
+ print(f"Loaded weights from {weight_path}")
+
+ # Loss and metrics
+ criterion = PaPsLoss(
+ l_center=config.l_center,
+ l_size=config.l_size,
+ l_shape=config.l_shape,
+ l_class=config.l_class,
+ beta=config.beta,
+ )
+
+ panoptic_meter = PanopticMeter(
+ num_classes=config.num_classes,
+ void_label=config.ignore_index if config.ignore_index != -1 else None,
+ )
+
+ # Inference
+ print("Testing . . .")
+ model.eval()
+ test_metrics, tables = iterate(
+ model,
+ data_loader=test_loader,
+ criterion=criterion,
+ panoptic_meter=panoptic_meter,
+ config=config,
+ optimizer=None,
+ mode="test",
+ device=config.device,
+ )
+
+ print(
+ "Loss {:.4f}, SQ {:.1f}, RQ {:.1f}, PQ {:.1f}".format(
+ test_metrics["test_loss"],
+ test_metrics["test_SQ"] * 100,
+ test_metrics["test_RQ"] * 100,
+ test_metrics["test_PQ"] * 100,
+ )
+ )
+ # print("test_metrics_SQ : ",test_metrics['test_SQ'])
+ # print("test_metrics_RQ : ",test_metrics['test_RQ'])
+ # print("test_metrics_PQ : ",test_metrics['test_PQ'])
+ save_results(fold + 1, test_metrics, tables, config)
+
+ if config.fold is None:
+ overall_performance(config)
+
+
+if __name__ == "__main__":
+ test_config = parser.parse_args()
+
+ # Try to load config from conf.json if it exists, otherwise use defaults
+ conf_path = os.path.join(test_config.weight_folder, "conf.json")
+ if os.path.exists(conf_path):
+ with open(conf_path) as file:
+ model_config = json.loads(file.read())
+ config = {**model_config, **vars(test_config)}
+ else:
+ print("Warning: conf.json not found, using test script parameters only")
+ # Set default model parameters for panoptic
+ default_config = {
+ "backbone": "utae",
+ "encoder_widths": "[64,64,64,128]",
+ "decoder_widths": "[32,32,64,128]",
+ "out_conv": "[32, 20]",
+ "str_conv_k": 4,
+ "str_conv_s": 2,
+ "str_conv_p": 1,
+ "agg_mode": "att_group",
+ "encoder_norm": "group",
+ "n_head": 16,
+ "d_model": 256,
+ "d_k": 4,
+ "num_classes": 20,
+ "ignore_index": -1,
+ "pad_value": 0,
+ "padding_mode": "reflect",
+ "ref_date": "2018-09-01",
+ "mono_date": None,
+ "rdm_seed": 1,
+ "supmax": True,
+ # PaPs specific parameters
+ "shape_size": 16,
+ "mask_conv": True,
+ "min_confidence": 0.2,
+ "min_remain": 0.5,
+ "mask_threshold": 0.4,
+ "l_center": 1,
+ "l_size": 1,
+ "l_shape": 1,
+ "l_class": 1,
+ "beta": 4,
+ }
+ config = {**default_config, **vars(test_config)}
+
+ config = argparse.Namespace(**config)
+ config.fold = test_config.fold
+
+ pprint.pprint(vars(config))
+ main(config)
diff --git a/examples/UTAE/test_semantic.py b/examples/UTAE/test_semantic.py
new file mode 100644
index 000000000..a5a2057fe
--- /dev/null
+++ b/examples/UTAE/test_semantic.py
@@ -0,0 +1,238 @@
+import argparse
+import glob
+import json
+import os
+import pprint
+import re
+
+import numpy as np
+import paddle
+import paddle.io as data
+import paddle.nn as nn
+from src import model_utils
+from src import utils
+from src.dataset import PASTIS_Dataset
+
+
+def prepare_output(res_dir: str):
+ os.makedirs(res_dir, exist_ok=True)
+ for k in range(1, 6):
+ os.makedirs(os.path.join(res_dir, f"Fold_{k}"), exist_ok=True)
+
+
+def _auto_pick_ckpt(fold_dir: str) -> str:
+
+ pref = os.path.join(fold_dir, "model.pdparams")
+ if os.path.isfile(pref):
+ return pref
+ cands = glob.glob(os.path.join(fold_dir, "*.pdparams"))
+ if not cands:
+ raise ValueError(f"No .pdparams in {fold_dir}")
+
+ def score(p):
+ m = re.search(r"miou_([0-9.]+)\.pdparams$", p)
+ return float(m.group(1)) if m else -1.0
+
+ return max(cands, key=lambda p: (score(p), os.path.getmtime(p)))
+
+
+def iterate_eval(model, data_loader, criterion, num_classes, ignore_index):
+ # IoU 兼容两种路径
+ try:
+ from src.learning.miou import IoU
+ except Exception:
+ from src.miou import IoU
+
+ iou_meter = IoU(num_classes=num_classes, ignore_index=ignore_index, cm_device="cpu")
+ loss_sum, nb = 0.0, 0
+ model.eval()
+ with paddle.no_grad():
+ for (x, dates), y in data_loader:
+ out = model(x, batch_positions=dates) # [B,C,H,W]
+ B, C, H, W = out.shape
+ logits = out.transpose([0, 2, 3, 1]).reshape([-1, C])
+ target = y.reshape([-1]).astype("int64")
+ loss = criterion(logits, target)
+ loss_sum += float(loss.numpy())
+ nb += 1
+ pred = nn.functional.softmax(out, axis=1).argmax(axis=1)
+ iou_meter.add(pred, y)
+ miou, acc = iou_meter.get_miou_acc()
+ return {
+ "test_loss": loss_sum / max(1, nb),
+ "test_accuracy": float(acc),
+ "test_IoU": float(miou),
+ }, iou_meter.confusion_matrix
+
+
+def _to_eval_string(v, fallback: str):
+
+ if v is None:
+ return fallback
+ if isinstance(v, (list, tuple)):
+ return "[" + ", ".join(str(x) for x in v) + "]"
+ return str(v)
+
+
+def main(cfg):
+ # 设备兜底
+ if cfg.device == "gpu" and not paddle.is_compiled_with_cuda():
+ print("⚠️ 当前环境未编译 CUDA,自动切到 CPU。")
+ cfg.device = "cpu"
+ paddle.set_device(cfg.device)
+ np.random.seed(cfg.rdm_seed)
+ paddle.seed(cfg.rdm_seed)
+ prepare_output(cfg.res_dir)
+ cfg.encoder_widths = _to_eval_string(
+ getattr(cfg, "encoder_widths", None), "[64,64,64,128]"
+ )
+ cfg.decoder_widths = _to_eval_string(
+ getattr(cfg, "decoder_widths", None), "[32,32,64,128]"
+ )
+ cfg.out_conv = _to_eval_string(getattr(cfg, "out_conv", None), "[32, 20]")
+
+ # 构建模型
+ model = model_utils.get_model(cfg, mode="semantic")
+ print(model)
+ print("TOTAL TRAINABLE PARAMETERS :", model_utils.get_ntrainparams(model))
+
+ # 折序列(与训练一致)
+ fold_sequence = [
+ [[1, 2, 3], [4], [5]],
+ [[2, 3, 4], [5], [1]],
+ [[3, 4, 5], [1], [2]],
+ [[4, 5, 1], [2], [3]],
+ [[5, 1, 2], [3], [4]],
+ ]
+
+ # —— 如果指定 --weight_file,仅评一个 fold —— #
+ if cfg.weight_file:
+ if cfg.fold is not None:
+ run_fold = cfg.fold
+ else:
+ m = re.search(r"[\\/](?:Fold_|fold_)(\d)[\\/]", cfg.weight_file)
+ run_fold = int(m.group(1)) if m else 1
+ seq = [fold_sequence[run_fold - 1]]
+ print(f"Single-fold mode (from weight_file): fold={run_fold}")
+ else:
+ seq = fold_sequence if cfg.fold is None else [fold_sequence[cfg.fold - 1]]
+
+ for idx, (_, _, test_fold) in enumerate(seq):
+ fold_id = (
+ (cfg.fold if cfg.fold is not None else (idx + 1))
+ if not cfg.weight_file
+ else run_fold
+ )
+
+ # 数据
+ ds = PASTIS_Dataset(
+ folder=cfg.dataset_folder,
+ norm=True,
+ reference_date=cfg.ref_date,
+ mono_date=cfg.mono_date,
+ target="semantic",
+ sats=["S2"],
+ folds=test_fold,
+ )
+ collate = lambda x: utils.pad_collate(x, pad_value=cfg.pad_value)
+ loader = data.DataLoader(
+ ds,
+ batch_size=cfg.batch_size,
+ shuffle=False,
+ drop_last=False,
+ collate_fn=collate,
+ num_workers=cfg.num_workers,
+ )
+ print(
+ f"#test samples: {len(ds)}, batch_size: {cfg.batch_size}, #batches: {len(loader)}"
+ )
+
+ # 权重
+ if cfg.weight_file:
+ wpath = cfg.weight_file
+ else:
+ if not cfg.weight_folder:
+ raise ValueError("Provide --weight_file or --weight_folder")
+ fold_dir = os.path.join(cfg.weight_folder, f"Fold_{fold_id}")
+ if not os.path.isdir(fold_dir):
+ raise ValueError(f"Fold dir not found: {fold_dir}")
+ wpath = _auto_pick_ckpt(fold_dir)
+ print(f"Loading weights: {wpath}")
+ sd = paddle.load(wpath)
+ state = sd["state_dict"] if isinstance(sd, dict) and "state_dict" in sd else sd
+ model.set_state_dict(state)
+
+ # 损失
+ w = paddle.ones([cfg.num_classes], dtype="float32")
+ if 0 <= cfg.ignore_index < cfg.num_classes:
+ w[cfg.ignore_index] = 0
+ criterion = nn.CrossEntropyLoss(weight=w)
+
+ # 推理
+ print("Testing ...")
+ metrics, cm = iterate_eval(
+ model, loader, criterion, cfg.num_classes, cfg.ignore_index
+ )
+ print(
+ f"[Fold {fold_id}] Loss {metrics['test_loss']:.4f}, "
+ f"Acc {metrics['test_accuracy']:.2f}, IoU {metrics['test_IoU']:.4f}"
+ )
+
+ # 保存
+ outd = os.path.join(cfg.res_dir, f"Fold_{fold_id}")
+ os.makedirs(outd, exist_ok=True)
+ with open(os.path.join(outd, "test_metrics.json"), "w", encoding="utf-8") as f:
+ json.dump(metrics, f, indent=2, ensure_ascii=False)
+ np.save(os.path.join(outd, "confusion_matrix.npy"), cm)
+ print(f"Saved metrics and confusion matrix to {outd}")
+
+ # --weight_file 触发的单折模式:跑完即结束
+ if cfg.weight_file:
+ break
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # 运行 / 数据
+ parser.add_argument("--dataset_folder", type=str, default="", help="PASTIS 根目录")
+ parser.add_argument("--res_dir", type=str, default="./inference_utae")
+ parser.add_argument("--fold", type=int, default=None, help="1..5;指定时只评该折")
+ parser.add_argument("--device", type=str, default="gpu", choices=["gpu", "cpu"])
+ parser.add_argument("--num_workers", type=int, default=0)
+
+ # 权重
+ parser.add_argument(
+ "--weight_folder", type=str, default="", help="results 根目录,自动为每折挑 ckpt"
+ )
+ parser.add_argument(
+ "--weight_file", type=str, default="", help="单个 .pdparams 路径(只评一个 fold)"
+ )
+
+ # 模型结构(与训练一致;可按需覆盖)
+ parser.add_argument("--model", type=str, default="utae")
+ parser.add_argument("--encoder_widths", type=str, default="[64,64,64,128]")
+ parser.add_argument("--decoder_widths", type=str, default="[32,32,64,128]")
+ parser.add_argument("--out_conv", type=str, default="[32, 20]")
+ parser.add_argument("--str_conv_k", type=int, default=4)
+ parser.add_argument("--str_conv_s", type=int, default=2)
+ parser.add_argument("--str_conv_p", type=int, default=1)
+ parser.add_argument("--agg_mode", type=str, default="att_group")
+ parser.add_argument("--encoder_norm", type=str, default="group")
+ parser.add_argument("--n_head", type=int, default=16)
+ parser.add_argument("--d_model", type=int, default=256)
+ parser.add_argument("--d_k", type=int, default=4)
+ parser.add_argument("--padding_mode", type=str, default="reflect")
+
+ # 标签 / 批量
+ parser.add_argument("--num_classes", type=int, default=20)
+ parser.add_argument("--ignore_index", type=int, default=-1)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--ref_date", type=str, default="2018-09-01")
+ parser.add_argument("--pad_value", type=float, default=0.0)
+ parser.add_argument("--rdm_seed", type=int, default=1)
+ parser.add_argument("--mono_date", type=str, default=None)
+
+ cfg = parser.parse_args()
+ pprint.pprint(vars(cfg))
+ main(cfg)
diff --git a/examples/UTAE/train_panoptic.py b/examples/UTAE/train_panoptic.py
new file mode 100644
index 000000000..4bc550756
--- /dev/null
+++ b/examples/UTAE/train_panoptic.py
@@ -0,0 +1,481 @@
+"""
+Main script for panoptic experiments (Paddle Version)
+Converted to PaddlePaddle
+"""
+import argparse
+import json
+import os
+import pprint
+import time
+
+import numpy as np
+import paddle
+from src import model_utils as model_utils
+from src.dataset import PASTIS_Dataset
+from src.learning.weight_init import weight_init
+from src.model_utils import get_ntrainparams
+from src.panoptic.metrics import PanopticMeter
+from src.panoptic.paps_loss import PaPsLoss
+from src.utils import pad_collate
+
+parser = argparse.ArgumentParser()
+# PaPs Parameters
+## Architecture Hyperparameters
+parser.add_argument("--shape_size", default=16, type=int, help="Shape size for PaPs")
+parser.add_argument(
+ "--no_mask_conv",
+ dest="mask_conv",
+ action="store_false",
+ help="With this flag no residual CNN is used after combination of global saliency and local shape.",
+)
+parser.add_argument(
+ "--backbone",
+ default="utae",
+ type=str,
+ help="Backbone encoder for PaPs (utae or uconvlstm)",
+)
+
+## Losses & metrics
+parser.add_argument(
+ "--l_center", default=1, type=float, help="Coefficient for centerness loss"
+)
+parser.add_argument("--l_size", default=1, type=float, help="Coefficient for size loss")
+parser.add_argument(
+ "--l_shape", default=0, type=float, help="Coefficient for shape loss"
+)
+parser.add_argument(
+ "--l_class", default=1, type=float, help="Coefficient for class loss"
+)
+parser.add_argument(
+ "--beta", default=4, type=float, help="Beta parameter for centerness loss"
+)
+parser.add_argument(
+ "--no_autotune",
+ dest="autotune",
+ action="store_false",
+ help="If this flag is used the confidence threshold for the pseudo-nms will NOT be tuned automatically on the validation set",
+)
+parser.add_argument(
+ "--no_supmax",
+ dest="supmax",
+ action="store_false",
+ help="If this flag is used, ALL local maxima are supervised (and not just the more confident center per ground truth object)",
+)
+parser.add_argument(
+ "--warmup",
+ default=5,
+ type=int,
+ help="Number of epochs to do with only the centerness loss as supervision.",
+)
+parser.add_argument(
+ "--val_metrics_only",
+ dest="val_metrics_only",
+ action="store_true",
+ help="If true, panoptic metrics are computed only on validation and test epochs.",
+)
+parser.add_argument(
+ "--val_every",
+ default=5,
+ type=int,
+ help="Interval in epochs between two validation steps.",
+)
+parser.add_argument(
+ "--val_after",
+ default=0,
+ type=int,
+ help="Do validation only after that many epochs.",
+)
+
+## Thresholds
+parser.add_argument(
+ "--min_remain",
+ default=0.5,
+ type=float,
+ help="Minimum remain fraction for the pseudo-nms.",
+)
+parser.add_argument(
+ "--mask_threshold",
+ default=0.4,
+ type=float,
+ help="Binary threshold for instance masks",
+)
+parser.add_argument(
+ "--min_confidence",
+ default=0.2,
+ type=float,
+ help="Minimum confidence threshold for pseudo-nms",
+)
+
+# U-TAE Hyperparameters (if using U-TAE backbone)
+parser.add_argument("--encoder_widths", default="[64,64,64,128]", type=str)
+parser.add_argument("--decoder_widths", default="[32,32,64,128]", type=str)
+parser.add_argument("--out_conv", default="[32, 20]")
+parser.add_argument("--str_conv_k", default=4, type=int)
+parser.add_argument("--str_conv_s", default=2, type=int)
+parser.add_argument("--str_conv_p", default=1, type=int)
+parser.add_argument("--agg_mode", default="att_group", type=str)
+parser.add_argument("--encoder_norm", default="group", type=str)
+parser.add_argument("--n_head", default=16, type=int)
+parser.add_argument("--d_model", default=256, type=int)
+parser.add_argument("--d_k", default=4, type=int)
+
+# Set-up parameters
+parser.add_argument(
+ "--dataset_folder",
+ default="/home/aistudio/PASTIS",
+ type=str,
+ help="Path to the folder where the results are saved.",
+)
+parser.add_argument(
+ "--res_dir",
+ default="./results",
+ help="Path to the folder where the results should be stored",
+)
+parser.add_argument(
+ "--num_workers", default=8, type=int, help="Number of data loading workers"
+)
+parser.add_argument("--rdm_seed", default=1, type=int, help="Random seed")
+parser.add_argument(
+ "--device",
+ default="gpu",
+ type=str,
+ help="Name of device to use for tensor computations (gpu/cpu)",
+)
+parser.add_argument(
+ "--display_step",
+ default=50,
+ type=int,
+ help="Interval in batches between display of training metrics",
+)
+parser.add_argument(
+ "--cache",
+ dest="cache",
+ action="store_true",
+ help="If specified, the whole dataset is kept in RAM",
+)
+
+# Training parameters
+parser.add_argument("--epochs", default=100, type=int, help="Number of epochs per fold")
+parser.add_argument("--batch_size", default=4, type=int, help="Batch size")
+parser.add_argument("--lr", default=0.01, type=float, help="Learning rate")
+parser.add_argument("--mono_date", default=None, type=str)
+parser.add_argument("--ref_date", default="2018-09-01", type=str)
+parser.add_argument(
+ "--fold",
+ default=None,
+ type=int,
+ help="Do only one of the five fold (between 1 and 5)",
+)
+parser.add_argument("--num_classes", default=20, type=int)
+parser.add_argument("--ignore_index", default=-1, type=int)
+parser.add_argument("--pad_value", default=0, type=float)
+parser.add_argument("--padding_mode", default="reflect", type=str)
+
+
+def recursive_todevice(x, device):
+ if isinstance(x, paddle.Tensor):
+ return x.cuda() if device == "gpu" else x.cpu()
+ else:
+ return [recursive_todevice(c, device) for c in x]
+
+
+def prepare_output(config):
+ os.makedirs(config.res_dir, exist_ok=True)
+ for fold in range(1, 6):
+ os.makedirs(os.path.join(config.res_dir, "Fold_{}".format(fold)), exist_ok=True)
+
+
+def checkpoint(log, config):
+ with open(os.path.join(config.res_dir, "trainlog.json"), "w") as outfile:
+ json.dump(log, outfile, indent=4)
+
+
+def save_results(metrics, config):
+ with open(os.path.join(config.res_dir, "test_metrics.json"), "w") as outfile:
+ json.dump(metrics, outfile, indent=4)
+
+
+def get_model(config):
+ """Create PaPs model with specified backbone"""
+ if config.backbone == "utae":
+ from src.backbones.utae import UTAE
+ from src.panoptic.paps import PaPs
+
+ encoder = UTAE(
+ input_dim=10, # PASTIS has 10 spectral bands
+ encoder_widths=eval(config.encoder_widths),
+ decoder_widths=eval(config.decoder_widths),
+ out_conv=eval(config.out_conv),
+ str_conv_k=config.str_conv_k,
+ str_conv_s=config.str_conv_s,
+ str_conv_p=config.str_conv_p,
+ agg_mode=config.agg_mode,
+ encoder_norm=config.encoder_norm,
+ n_head=config.n_head,
+ d_model=config.d_model,
+ d_k=config.d_k,
+ encoder=True, # Important: set to True for PaPs
+ return_maps=True, # Important: return feature maps
+ pad_value=config.pad_value,
+ padding_mode=config.padding_mode,
+ )
+
+ model = PaPs(
+ encoder=encoder,
+ num_classes=config.num_classes,
+ shape_size=config.shape_size,
+ mask_conv=config.mask_conv,
+ min_confidence=config.min_confidence,
+ min_remain=config.min_remain,
+ mask_threshold=config.mask_threshold,
+ )
+ else:
+ raise NotImplementedError(f"Backbone {config.backbone} not implemented yet")
+
+ return model
+
+
+def iterate(
+ model,
+ data_loader,
+ criterion,
+ panoptic_meter,
+ config,
+ optimizer=None,
+ mode="train",
+ device="gpu",
+):
+ loss_meter = 0
+ batch_count = 0
+ t_start = time.time()
+
+ for i, batch in enumerate(data_loader):
+ if device == "gpu":
+ batch = recursive_todevice(batch, device)
+
+ (x, dates), targets = batch
+ targets = targets.astype("float32")
+
+ # Forward pass
+ if mode != "train":
+ with paddle.no_grad():
+ heatmap_only = mode == "train" and config.warmup > 0
+ predictions = model(
+ x,
+ batch_positions=dates,
+ zones=targets[:, :, :, 2:3] if config.supmax else None,
+ heatmap_only=heatmap_only,
+ pseudo_nms=True,
+ ) # Enable pseudo_nms for testing
+ else:
+ heatmap_only = i < config.warmup * len(data_loader) // config.epochs
+ predictions = model(
+ x,
+ batch_positions=dates,
+ zones=targets[:, :, :, 2:3] if config.supmax else None,
+ heatmap_only=heatmap_only,
+ pseudo_nms=True,
+ ) # Enable pseudo_nms for testing
+
+ # Compute loss
+ loss = criterion(predictions, targets, heatmap_only=heatmap_only)
+
+ if mode == "train":
+ optimizer.clear_grad()
+ loss.backward()
+ optimizer.step()
+
+ # Update metrics
+ loss_meter += loss.item()
+ batch_count += 1
+
+ # Add to panoptic meter (if not warmup)
+ if not heatmap_only and not config.val_metrics_only:
+ # Check if we have panoptic predictions (minimal debug)
+ if predictions["pano_semantic"] is not None:
+ panoptic_meter.add(predictions, targets)
+ else:
+ # If no panoptic predictions, skip adding to meter
+ pass
+
+ if (i + 1) % config.display_step == 0:
+ SQ, RQ, PQ = panoptic_meter.value()
+ print(
+ f"{mode} - Step [{i+1}/{len(data_loader)}] Loss: {loss_meter/batch_count:.4f} "
+ f"SQ: {SQ*100:.1f} RQ: {RQ*100:.1f} PQ: {PQ*100:.1f}"
+ )
+
+ t_end = time.time()
+ total_time = t_end - t_start
+
+ # Final metrics
+ SQ, RQ, PQ = panoptic_meter.value()
+ avg_loss = loss_meter / batch_count if batch_count > 0 else 0
+
+ return avg_loss, SQ.item(), RQ.item(), PQ.item(), total_time
+
+
+def main(config):
+ paddle.seed(config.rdm_seed)
+ np.random.seed(config.rdm_seed)
+
+ prepare_output(config)
+
+ # Save configuration for testing
+ with open(os.path.join(config.res_dir, "conf.json"), "w") as f:
+ json.dump(vars(config), f, indent=4)
+
+ if config.fold is not None:
+ folds = [config.fold]
+ else:
+ folds = [1, 2, 3, 4, 5]
+
+ for fold in folds:
+ print(f"Starting fold {fold}")
+
+ # Dataset definition
+ dt_args = dict(
+ folder=config.dataset_folder,
+ norm=True,
+ reference_date=config.ref_date,
+ mono_date=config.mono_date,
+ target="instance", # Important: use instance target for panoptic
+ sats=["S2"],
+ )
+
+ # 5-fold split
+ train_folds = [f for f in [1, 2, 3, 4, 5] if f != fold]
+ val_fold = [fold]
+ _test_fold = [fold] # Same as validation for now
+
+ dt_train = PASTIS_Dataset(**dt_args, folds=train_folds, cache=config.cache)
+ dt_val = PASTIS_Dataset(**dt_args, folds=val_fold, cache=config.cache)
+
+ print(f"Train samples: {len(dt_train)}, Val samples: {len(dt_val)}")
+
+ collate_fn = lambda x: pad_collate(x, pad_value=config.pad_value)
+ train_loader = paddle.io.DataLoader(
+ dt_train,
+ batch_size=config.batch_size,
+ shuffle=True,
+ num_workers=config.num_workers,
+ collate_fn=collate_fn,
+ )
+ val_loader = paddle.io.DataLoader(
+ dt_val,
+ batch_size=config.batch_size,
+ shuffle=False,
+ num_workers=config.num_workers,
+ collate_fn=collate_fn,
+ )
+
+ # Model definition
+ model = get_model(config)
+ model.apply(weight_init)
+ config.N_params = get_ntrainparams(model)
+
+ if config.device == "gpu":
+ # Paddle automatically uses GPU when available, no need for explicit .cuda()
+ pass
+
+ print(f"Model {config.backbone} - {config.N_params} trainable parameters")
+
+ # Loss and optimizer
+ criterion = PaPsLoss(
+ l_center=config.l_center,
+ l_size=config.l_size,
+ l_shape=config.l_shape, # Re-enable shape loss
+ l_class=config.l_class,
+ beta=config.beta,
+ )
+ optimizer = paddle.optimizer.Adam(
+ parameters=model.parameters(), learning_rate=config.lr
+ )
+
+ # Training
+ trainlog = {}
+ best_PQ = 0
+
+ for epoch in range(1, config.epochs + 1):
+ print(f"Epoch {epoch}/{config.epochs}")
+
+ model.train()
+ train_panoptic_meter = PanopticMeter(
+ num_classes=config.num_classes,
+ void_label=config.ignore_index if config.ignore_index != -1 else None,
+ )
+ train_loss, train_SQ, train_RQ, train_PQ, train_time = iterate(
+ model,
+ train_loader,
+ criterion,
+ train_panoptic_meter,
+ config,
+ optimizer,
+ "train",
+ config.device,
+ )
+
+ if epoch % config.val_every == 0 and epoch > config.val_after:
+ model.eval()
+ val_panoptic_meter = PanopticMeter(
+ num_classes=config.num_classes,
+ void_label=config.ignore_index
+ if config.ignore_index != -1
+ else None,
+ )
+ val_loss, val_SQ, val_RQ, val_PQ, val_time = iterate(
+ model,
+ val_loader,
+ criterion,
+ val_panoptic_meter,
+ config,
+ mode="val",
+ device=config.device,
+ )
+
+ print(
+ f"Train - Loss: {train_loss:.4f}, SQ: {train_SQ*100:.1f}, RQ: {train_RQ*100:.1f}, PQ: {train_PQ*100:.1f}"
+ )
+ print(
+ f"Val - Loss: {val_loss:.4f}, SQ: {val_SQ*100:.1f}, RQ: {val_RQ*100:.1f}, PQ: {val_PQ*100:.1f}"
+ )
+
+ trainlog[epoch] = {
+ "train_loss": train_loss,
+ "train_SQ": train_SQ,
+ "train_RQ": train_RQ,
+ "train_PQ": train_PQ,
+ "val_loss": val_loss,
+ "val_SQ": val_SQ,
+ "val_RQ": val_RQ,
+ "val_PQ": val_PQ,
+ }
+
+ checkpoint(trainlog, config)
+
+ if val_PQ >= best_PQ:
+ best_PQ = val_PQ
+ paddle.save(
+ {
+ "epoch": epoch,
+ "state_dict": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ },
+ os.path.join(config.res_dir, f"Fold_{fold}", "model.pdparams"),
+ )
+ else:
+ trainlog[epoch] = {
+ "train_loss": train_loss,
+ "train_SQ": train_SQ,
+ "train_RQ": train_RQ,
+ "train_PQ": train_PQ,
+ }
+ checkpoint(trainlog, config)
+
+ print(f"Fold {fold} completed. Best PQ: {best_PQ:.4f}")
+
+
+if __name__ == "__main__":
+ config = parser.parse_args()
+ pprint.pprint(vars(config))
+ main(config)
diff --git a/examples/UTAE/train_semantic.py b/examples/UTAE/train_semantic.py
new file mode 100644
index 000000000..696cc2bbe
--- /dev/null
+++ b/examples/UTAE/train_semantic.py
@@ -0,0 +1,340 @@
+import argparse
+import json
+import os
+import pprint
+import time
+from typing import Tuple
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from src import model_utils
+from src import utils
+from src.dataset import PASTIS_Dataset
+
+# ---------- Args ----------
+parser = argparse.ArgumentParser()
+
+# Model & arch
+parser.add_argument(
+ "--model",
+ default="utae",
+ type=str,
+ help="utae/unet3d/fpn/convlstm/convgru/uconvlstm/buconvlstm",
+)
+parser.add_argument("--encoder_widths", default="[64,64,64,128]", type=str)
+parser.add_argument("--decoder_widths", default="[32,32,64,128]", type=str)
+parser.add_argument("--out_conv", default="[32, 20]", type=str)
+parser.add_argument("--str_conv_k", default=4, type=int)
+parser.add_argument("--str_conv_s", default=2, type=int)
+parser.add_argument("--str_conv_p", default=1, type=int)
+parser.add_argument("--agg_mode", default="att_group", type=str)
+parser.add_argument("--encoder_norm", default="group", type=str)
+parser.add_argument("--n_head", default=16, type=int)
+parser.add_argument("--d_model", default=256, type=int)
+parser.add_argument("--d_k", default=4, type=int)
+parser.add_argument("--padding_mode", default="reflect", type=str)
+
+# Data / runtime
+parser.add_argument("--dataset_folder", default="", type=str, help="PASTIS root")
+parser.add_argument("--res_dir", default="./results", type=str, help="Output dir")
+parser.add_argument("--num_workers", default=8, type=int)
+parser.add_argument("--rdm_seed", default=1, type=int)
+parser.add_argument("--device", default="gpu", type=str, help="gpu/cpu")
+parser.add_argument("--display_step", default=50, type=int)
+parser.add_argument("--cache", action="store_true", help="Keep dataset in RAM")
+parser.add_argument("--mono_date", default=None, type=str)
+parser.add_argument("--ref_date", default="2018-09-01", type=str)
+
+# Training
+parser.add_argument("--epochs", default=100, type=int)
+parser.add_argument("--batch_size", default=4, type=int)
+parser.add_argument("--lr", default=1e-3, type=float)
+parser.add_argument(
+ "--lr_decay", default="none", type=str, choices=["none", "step", "cosine"]
+)
+parser.add_argument("--lr_step_size", default=20, type=int)
+parser.add_argument("--lr_gamma", default=0.5, type=float)
+parser.add_argument("--grad_clip", default=0.0, type=float, help="0 to disable")
+parser.add_argument("--amp", action="store_true", help="Enable mixed precision")
+parser.add_argument(
+ "--early_stopping", default=0, type=int, help="Patience on val mIoU (0=disabled)"
+)
+parser.add_argument("--fold", default=None, type=int, help="1..5 for single fold")
+parser.add_argument("--val_every", default=1, type=int)
+
+# Labels
+parser.add_argument("--num_classes", default=20, type=int)
+parser.add_argument("--ignore_index", default=-1, type=int)
+parser.add_argument("--pad_value", default=0, type=float)
+
+
+# ---------- Helpers ----------
+def set_seed(seed: int):
+ np.random.seed(seed)
+ paddle.seed(seed)
+
+
+def prepare_output(config):
+ os.makedirs(config.res_dir, exist_ok=True)
+ for f in range(1, 6):
+ os.makedirs(os.path.join(config.res_dir, f"Fold_{f}"), exist_ok=True)
+
+
+def save_conf(config):
+ path = os.path.join(config.res_dir, "conf.json")
+ with open(path, "w") as f:
+ json.dump(vars(config), f, indent=2)
+ # also copy into each fold directory for convenience
+ for f in range(1, 6):
+ with open(os.path.join(config.res_dir, f"Fold_{f}", "conf.json"), "w") as g:
+ json.dump(vars(config), g, indent=2)
+
+
+def save_trainlog(trainlog: dict, config):
+ with open(os.path.join(config.res_dir, "trainlog.json"), "w") as f:
+ json.dump(trainlog, f, indent=2)
+
+
+def save_metrics(metrics: dict, config):
+ with open(os.path.join(config.res_dir, "test_metrics.json"), "w") as f:
+ json.dump(metrics, f, indent=2)
+
+
+def get_iou_class():
+ try:
+ from src.learning.miou import IoU
+ except Exception:
+ from src.miou import IoU
+ return IoU
+
+
+def iterate(
+ model, loader, criterion, device="gpu", optimizer=None, display_step=50
+) -> Tuple[float, float, float]:
+ IoU = get_iou_class()
+ iou_meter = IoU(
+ num_classes=config.num_classes,
+ ignore_index=config.ignore_index,
+ cm_device="cpu",
+ )
+ loss_sum = 0.0
+ n_batches = 0
+
+ model.train() if optimizer is not None else model.eval()
+ scaler = None
+ if optimizer is not None and config.amp:
+ scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+ for i, batch in enumerate(loader):
+ (x, dates), y = batch
+ if optimizer is not None:
+ if config.amp:
+ with paddle.amp.auto_cast():
+ out = model(x, batch_positions=dates)
+ B, C, H, W = out.shape
+ out_reshaped = out.transpose([0, 2, 3, 1]).reshape([-1, C])
+ y_reshaped = y.reshape([-1]).astype("int64")
+ loss = criterion(out_reshaped, y_reshaped)
+ scaler.scale(loss).backward()
+ scaler.minimize(optimizer, loss)
+ optimizer.clear_grad()
+ else:
+ out = model(x, batch_positions=dates)
+ B, C, H, W = out.shape
+ out_reshaped = out.transpose([0, 2, 3, 1]).reshape([-1, C])
+ y_reshaped = y.reshape([-1]).astype("int64")
+ loss = criterion(out_reshaped, y_reshaped)
+ loss.backward()
+ optimizer.step()
+ optimizer.clear_grad()
+ else:
+ with paddle.no_grad():
+ out = model(x, batch_positions=dates)
+ B, C, H, W = out.shape
+ out_reshaped = out.transpose([0, 2, 3, 1]).reshape([-1, C])
+ y_reshaped = y.reshape([-1]).astype("int64")
+ loss = criterion(out_reshaped, y_reshaped)
+
+ pred = nn.functional.softmax(out, axis=1).argmax(axis=1)
+ iou_meter.add(pred, y)
+ loss_sum += float(loss.numpy())
+ n_batches += 1
+
+ if (i + 1) % display_step == 0:
+ miou, acc = iou_meter.get_miou_acc()
+ mode = "train" if optimizer is not None else "val"
+ print(
+ f"{mode} step {i+1}/{len(loader)} loss:{loss_sum/n_batches:.4f} acc:{acc:.3f} mIoU:{miou:.3f}"
+ )
+
+ miou, acc = iou_meter.get_miou_acc()
+ return loss_sum / max(1, n_batches), float(acc), float(miou)
+
+
+def main(config):
+ set_seed(config.rdm_seed)
+ if config.device == "gpu" and paddle.is_compiled_with_cuda():
+ paddle.set_device("gpu")
+ else:
+ paddle.set_device("cpu")
+ config.device = "cpu"
+
+ prepare_output(config)
+ save_conf(config)
+
+ folds = [config.fold] if config.fold is not None else [1, 2, 3, 4, 5]
+ trainlog = {}
+
+ for fold in folds:
+ print(f"===== Fold {fold} =====")
+ # Data
+ dt_train = PASTIS_Dataset(
+ folder=config.dataset_folder,
+ norm=True,
+ target="semantic",
+ folds=[f for f in range(1, 6) if f != fold],
+ cache=config.cache,
+ )
+ dt_val = PASTIS_Dataset(
+ folder=config.dataset_folder,
+ norm=True,
+ target="semantic",
+ folds=[fold],
+ cache=config.cache,
+ )
+
+ collate_fn = lambda x: utils.pad_collate(x, pad_value=config.pad_value)
+ train_loader = paddle.io.DataLoader(
+ dt_train,
+ batch_size=config.batch_size,
+ shuffle=True,
+ num_workers=config.num_workers,
+ collate_fn=collate_fn,
+ )
+ val_loader = paddle.io.DataLoader(
+ dt_val,
+ batch_size=config.batch_size,
+ shuffle=False,
+ num_workers=config.num_workers,
+ collate_fn=collate_fn,
+ )
+
+ # Model
+ model = model_utils.get_model(config, mode="semantic")
+ print(model)
+ print("TOTAL TRAINABLE PARAMETERS :", model_utils.get_ntrainparams(model))
+
+ # Optimizer, grad clip
+ grad_clip = None
+ if config.grad_clip and config.grad_clip > 0:
+ grad_clip = paddle.nn.ClipGradByGlobalNorm(config.grad_clip)
+ optimizer = paddle.optimizer.Adam(
+ learning_rate=config.lr, parameters=model.parameters(), grad_clip=grad_clip
+ )
+ # LR scheduler
+ if config.lr_decay == "step":
+ sched = paddle.optimizer.lr.StepDecay(
+ learning_rate=config.lr,
+ step_size=config.lr_step_size,
+ gamma=config.lr_gamma,
+ )
+ optimizer.set_lr_scheduler(sched)
+ elif config.lr_decay == "cosine":
+ sched = paddle.optimizer.lr.CosineAnnealingDecay(
+ learning_rate=config.lr, T_max=config.epochs
+ )
+ optimizer.set_lr_scheduler(sched)
+ else:
+ sched = None
+
+ criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_index)
+
+ # Train
+ best_miou = -1.0
+ epochs_no_improve = 0
+
+ for epoch in range(1, config.epochs + 1):
+ t0 = time.time()
+ model.train()
+ tr_loss, tr_acc, tr_miou = iterate(
+ model,
+ train_loader,
+ criterion,
+ device=config.device,
+ optimizer=optimizer,
+ display_step=config.display_step,
+ )
+ model.eval()
+ va_loss, va_acc, va_miou = iterate(
+ model,
+ val_loader,
+ criterion,
+ device=config.device,
+ optimizer=None,
+ display_step=config.display_step,
+ )
+ dt = time.time() - t0
+
+ # Scheduler step (if using)
+ if sched is not None:
+ sched.step()
+
+ print(
+ f"Epoch {epoch}/{config.epochs} "
+ f"train: loss {tr_loss:.4f} acc {tr_acc:.3f} miou {tr_miou:.3f} | "
+ f"val: loss {va_loss:.4f} acc {va_acc:.3f} miou {va_miou:.3f} "
+ f"({dt/60:.1f} min)"
+ )
+
+ # Save best
+ fold_dir = os.path.join(config.res_dir, f"Fold_{fold}")
+ os.makedirs(fold_dir, exist_ok=True)
+ ckpt_path = os.path.join(
+ fold_dir, f"model_epoch_{epoch}_miou_{va_miou:.3f}.pdparams"
+ )
+ paddle.save(model.state_dict(), ckpt_path)
+
+ if va_miou > best_miou:
+ best_miou = va_miou
+ # also update alias
+ alias = os.path.join(fold_dir, "model.pdparams")
+ paddle.save(model.state_dict(), alias)
+ epochs_no_improve = 0
+ else:
+ epochs_no_improve += 1
+
+ # Log
+ trainlog[f"fold{fold}_epoch{epoch}"] = {
+ "train_loss": tr_loss,
+ "train_acc": tr_acc,
+ "train_miou": tr_miou,
+ "val_loss": va_loss,
+ "val_acc": va_acc,
+ "val_miou": va_miou,
+ "best_miou": best_miou,
+ "lr": float(optimizer.get_lr())
+ if hasattr(optimizer, "get_lr")
+ else config.lr,
+ "time_min": dt / 60.0,
+ }
+ save_trainlog(trainlog, config)
+
+ # Early stopping
+ if config.early_stopping and epochs_no_improve >= config.early_stopping:
+ print(f"Early stopping triggered (patience={config.early_stopping}).")
+ break
+
+ print(f"[Fold {fold}] best mIoU = {best_miou:.3f}")
+
+ print("Training done.")
+
+
+if __name__ == "__main__":
+ config = parser.parse_args()
+ # device
+ if config.device not in ["gpu", "cpu"]:
+ config.device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
+ # Print configuration
+ pprint.pprint(vars(config))
+ main(config)
diff --git a/mkdocs.yml b/mkdocs.yml
index 5260fef71..5302a2082 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -117,6 +117,7 @@ nav:
- FuXi: zh/examples/fuxi.md
- UNetFormer: zh/examples/unetformer.md
- WGAN_GP: zh/examples/wgan_gp.md
+ - UTAE: zh/examples/UTAE.md
- 化学科学(AI for Chemistry):
- SMC Reac: zh/examples/smc_reac.md
- Moflow: zh/examples/moflow.md