import os
import netCDF4
import numpy as np
from typing import Literal
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import keras, float32 #type: ignore
# Attention-gated residual U-Net model builder.
class Att_Res_UNet():
    """Builder for an attention-gated residual U-Net Keras model.

    The constructor only stores hyper-parameters; no Keras layer is created
    until ``make_unet_model`` is called. All tensors are handled as
    channels-last (the code indexes the channel dimension as axis 3).
    """

    def __init__(self, list_predictors, list_targets, patch_dim, batch_size, n_filters, activation, kernel_initializer, batch_norm, pooling_type, dropout):
        """Store the model hyper-parameters.

        Args:
            list_predictors: sequence of input feature names; its length is
                the number of input channels.
            list_targets: sequence of target names; its length is the number
                of output channels.
            patch_dim: iterable unpacked into the ``Input`` shape in front of
                the channel axis (the 2-D layers used imply two spatial sizes).
            batch_size: stored only; no method of this class reads it.
            n_filters: indexable with at least six entries (indices 0-5 are
                read by ``make_unet_model``).
            activation: activation used inside the residual blocks.
            kernel_initializer: initializer for the 3x3 (or custom-size)
                convolutions of the residual blocks.
            batch_norm: batch normalisation is inserted in the residual
                blocks when this compares equal to ``True``.
            pooling_type: ``"Max"`` or ``"Average"``.
            dropout: rate given to the ``Dropout`` layer after each pooling.
        """
        self.list_predictors = list_predictors
        self.list_targets = list_targets
        self.patch_dim = patch_dim
        self.batch_size = batch_size
        self.n_filters = n_filters
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.batch_norm = batch_norm
        self.pooling_type = pooling_type
        self.dropout = dropout
        self.n_predictors = len(list_predictors)
        self.n_targets = len(list_targets)

    def repeat_elem(self, tensor, rep):
        """Repeat ``tensor`` ``rep`` times along axis 3 through a Lambda layer."""
        return keras.layers.Lambda(
            lambda x, repnum: keras.backend.repeat_elements(x, repnum, axis=3),
            arguments={'repnum': rep},
        )(tensor)

    def gating_signal(self, x, n_filters, batch_norm = False):
        """Return the gating tensor: 1x1 convolution, optional BN, then ReLU.

        NOTE(review): ``upsample_block`` calls this without ``batch_norm``,
        so the gating path never uses BN regardless of ``self.batch_norm``.
        Left as is because changing it would alter the built architecture.
        """
        x = keras.layers.Conv2D(n_filters, (1, 1), padding="same")(x)
        if batch_norm == True:
            x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation("relu")(x)
        return x

    def attention_block(self, x, g, inter_shape):
        """Weight the skip tensor ``x`` by an attention map derived from ``g``.

        Args:
            x: skip-connection tensor (4-D, channels last).
            g: gating tensor from ``gating_signal``.
            inter_shape: number of filters of the intermediate convolutions.

        Returns:
            Tensor with the same number of channels as ``x``, batch-normalised.

        The static shapes of ``x`` and ``g`` must be known: they are used to
        compute integer stride / up-sampling factors.
        """
        shape_x = keras.backend.int_shape(x)
        shape_g = keras.backend.int_shape(g)

        # Halve the spatial size of x with a strided convolution.
        theta_x = keras.layers.Conv2D(inter_shape, kernel_size=(2, 2), strides=(2, 2), padding="same")(x)
        shape_theta_x = keras.backend.int_shape(theta_x)

        # Project g and resize it to theta_x's spatial size.
        phi_g = keras.layers.Conv2D(inter_shape, kernel_size=(1, 1), padding="same")(g)
        upsample_g = keras.layers.Conv2DTranspose(
            inter_shape, (3, 3),
            strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
            padding="same")(phi_g)

        # Additive attention: sum, ReLU, 1-channel projection, sigmoid.
        concat_xg = keras.layers.add([upsample_g, theta_x])
        act_xg = keras.layers.Activation("relu")(concat_xg)
        psi = keras.layers.Conv2D(1, (1, 1), padding="same")(act_xg)
        sigmoid_xg = keras.layers.Activation("sigmoid")(psi)
        shape_sigmoid = keras.backend.int_shape(sigmoid_xg)

        # Bring the single-channel map back to x's size and channel count.
        upsample_psi = keras.layers.UpSampling2D(
            size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)
        upsample_psi = self.repeat_elem(upsample_psi, shape_x[3])

        y = keras.layers.multiply([upsample_psi, x])
        result = keras.layers.Conv2D(shape_x[3], (1, 1), padding="same")(y)
        result_bn = keras.layers.BatchNormalization()(result)
        return result_bn

    def residual_conv_block(self, x, n_filters, padding: Literal["valid", "same"] = "same", kernel_size=(3,3)):
        """Two convolutions plus a 1x1-convolution shortcut, summed and activated.

        Batch normalisation follows each convolution (shortcut included) when
        ``self.batch_norm == True``. The shortcut convolution uses the Keras
        default initializer, not ``self.kernel_initializer``.
        """
        conv = keras.layers.Conv2D(n_filters, kernel_size=kernel_size, padding=padding, kernel_initializer=self.kernel_initializer)(x)
        if self.batch_norm == True:
            conv = keras.layers.BatchNormalization(axis=3)(conv)
        conv = keras.layers.Activation(self.activation)(conv)

        conv = keras.layers.Conv2D(n_filters, kernel_size=kernel_size, padding=padding, kernel_initializer=self.kernel_initializer)(conv)
        if self.batch_norm == True:
            conv = keras.layers.BatchNormalization(axis=3)(conv)

        # 1x1 convolution so the shortcut has n_filters channels like conv.
        shortcut = keras.layers.Conv2D(n_filters, kernel_size=(1, 1), padding=padding)(x)
        if self.batch_norm == True:
            shortcut = keras.layers.BatchNormalization(axis=3)(shortcut)

        res_path = keras.layers.add([shortcut, conv])
        res_path = keras.layers.Activation(self.activation)(res_path)
        return res_path

    def downsample_block(self, x, n_filters, pool_size = (2,2), kernel_size = (3,3), strides = 2):
        """Encoder stage: residual block, pooling, dropout.

        Returns:
            ``(f, p)`` where ``f`` is the residual-block output (used as the
            skip connection) and ``p`` is the pooled, dropped-out tensor.

        Raises:
            ValueError: if ``self.pooling_type`` is neither ``"Max"`` nor
                ``"Average"``.
        """
        f = self.residual_conv_block(x, n_filters, kernel_size=kernel_size)
        if self.pooling_type == "Max":
            p = keras.layers.MaxPool2D(pool_size=pool_size, strides=strides)(f)
        elif self.pooling_type == "Average":
            p = keras.layers.AveragePooling2D(pool_size=pool_size, strides=strides)(f)
        else:
            raise ValueError("Invalid pooling type. Must be 'Max' or 'Average'.")
        p = keras.layers.Dropout(self.dropout)(p)
        return (f, p)

    def upsample_block(self, x, conv_features, n_filters, kernel_size = (3,3), strides = 2, padding = "same"):
        """Decoder stage: attention-gated skip, 2x up-sampling, concat, residual block.

        ``strides`` and ``padding`` are accepted for interface compatibility
        but are not used by the body.
        """
        gating = self.gating_signal(x, n_filters)
        att = self.attention_block(conv_features, gating, n_filters)
        up_att = keras.layers.UpSampling2D(size=(2, 2), data_format="channels_last")(x)
        up_att = keras.layers.concatenate([up_att, att], axis=3)
        up_conv = self.residual_conv_block(up_att, n_filters, kernel_size=kernel_size)
        return up_conv

    def make_unet_model(self):
        """Assemble and return the ``keras.Model`` (named ``"U-Net"``).

        Five encoder stages (the first with 7x7 kernels), a bottleneck
        residual block using ``n_filters[5]``, five decoder stages and a
        linear 1x1 output convolution named ``"psi"`` with ``n_targets``
        channels.

        NOTE(review): with five 2x poolings mirrored by five 2x up-samplings,
        the spatial sizes in ``patch_dim`` presumably have to be divisible by
        32 for the skip concatenations to line up - confirm with callers.
        """
        n_levels = 5
        inputs = keras.layers.Input(shape=(*self.patch_dim, self.n_predictors))

        # Encoder (downsample); keep each stage's features for the skips.
        skips = []
        pooled = inputs
        for level in range(n_levels):
            kernel_size = (7, 7) if level == 0 else (3, 3)
            features, pooled = self.downsample_block(pooled, self.n_filters[level], kernel_size=kernel_size)
            skips.append(features)

        # Bottleneck
        decoded = self.residual_conv_block(pooled, self.n_filters[n_levels])

        # Decoder (upsample), deepest skip first.
        for level in reversed(range(n_levels)):
            decoded = self.upsample_block(decoded, skips[level], self.n_filters[level])

        # outputs
        outputs = keras.layers.Conv2D(self.n_targets, (1, 1), padding="same", activation="linear", dtype=float32, name="psi")(decoded)
        unet_model = keras.Model(inputs, outputs, name="U-Net")
        return unet_model

    def featname2tuple(self, feature_name):
        """Split ``"<varname>_<channel>"`` into ``(varname, int(channel))``.

        The split is done on the last underscore, so ``varname`` may itself
        contain underscores.

        Raises:
            IndexError: if ``feature_name`` contains no underscore.
            ValueError: if the part after the last underscore is not an integer.
        """
        parts = feature_name.rsplit('_', 1)  # Split from the right, max 1 split
        varname = parts[0]
        channel = int(parts[1])
        return varname, channel
# Keras Sequence that reads (input, target) sample pairs from netCDF files.
class Data_generator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(X, y)`` batches read from netCDF files.

    Files are looked up as ``<path_data>/<run:04d>/<sample:03d>.nc`` with
    ``run = r + 1 + startrun``. For a pair ``(r, s)`` the predictors are read
    from sample ``s`` and the targets from sample ``s + 1`` of the same run,
    so each run provides ``sampleperrun - 1`` pairs.
    """

    def __init__(self, nrun, startrun, shuffle, batch_size, dim, path_data, list_predictors, list_targets, sampleperrun = 100):
        """Store the configuration and build the initial index order.

        Args:
            nrun: number of runs to draw samples from.
            startrun: offset added to the run number when building paths.
            shuffle: indexes are shuffled (initially and after every epoch)
                when this compares equal to ``True``.
            batch_size: number of samples per batch.
            dim: iterable unpacked into the batch array shape between the
                batch axis and the channel axis.
            path_data: root directory containing one sub-directory per run.
            list_predictors: feature names of the form ``"<varname>_<channel>"``.
            list_targets: target names, same format as ``list_predictors``.
            sampleperrun: number of sample files per run.
        """
        super().__init__()
        self.nrun = nrun
        self.startrun = startrun
        self.sampleperrun = sampleperrun
        # One pair per consecutive (s, s + 1) couple of samples in a run.
        self.n = nrun * (sampleperrun - 1)
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.path_data = path_data
        self.dim = dim
        self.list_predictors = list_predictors
        self.list_targets = list_targets
        self.npredictors = len(self.list_predictors)
        self.ntargets = len(self.list_targets)
        self._reset_indexes()

    def _reset_indexes(self):
        """Rebuild ``self.indexes`` as 0..n-1, shuffled when requested."""
        self.indexes = np.arange(self.n)
        if self.shuffle == True:
            rng = np.random.default_rng()
            rng.shuffle(self.indexes)

    def __len__(self):  # Number of batches per epoch
        """Return the number of full batches; a trailing partial batch is dropped."""
        return self.n // self.batch_size

    def featname2tuple(self, feature_name):
        """Split ``"<varname>_<channel>"`` into ``(varname, int(channel))``.

        Same parsing as ``Att_Res_UNet.featname2tuple``; defined here because
        ``__data_generation`` calls it on the generator instance.

        Raises:
            IndexError: if ``feature_name`` contains no underscore.
            ValueError: if the part after the last underscore is not an integer.
        """
        parts = feature_name.rsplit('_', 1)  # Split from the right, max 1 split
        return parts[0], int(parts[1])

    def index2rs(self, index):  # From index of the sample to the number of the run and sample
        """Map a flat pair index to ``(run, sample)``.

        The divisor is ``sampleperrun - 1`` (pairs per run) to agree with
        ``self.n``: ``s`` stays in ``0 .. sampleperrun - 2`` so that the
        target sample ``s + 1`` never exceeds ``sampleperrun - 1``, and every
        run is fully covered. Raises ``ZeroDivisionError`` if called with
        ``sampleperrun == 1`` (in that case ``n`` is 0 and no index exists).
        """
        pairs_per_run = self.sampleperrun - 1
        r, s = divmod(index, pairs_per_run)
        return r, s

    def __getitem__(self, index):  # Generate one batch of data
        """Return batch number ``index`` as a tuple ``(X, y)``."""
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        list_r_s = [self.index2rs(k) for k in indexes]
        X, y = self.__data_generation(list_r_s)
        return (X, y)

    def on_epoch_end(self):  # Updates indexes after each epoch
        """Reset (and optionally reshuffle) the index order."""
        self._reset_indexes()

    def __data_generation(self, list_r_s):  # Generates data containing batch_size samples
        """Read the ``(run, sample)`` pairs in ``list_r_s`` into batch arrays.

        Returns:
            ``(X, y)``: float arrays of shape ``(batch_size, *dim, npredictors)``
            and ``(batch_size, *dim, ntargets)``, pre-filled with NaN. Each
            channel is read as ``variables[varname][0, channel, :, :]``.
        """
        X = np.full((self.batch_size, *self.dim, self.npredictors), np.nan)
        y = np.full((self.batch_size, *self.dim, self.ntargets), np.nan)

        # Feature names do not depend on the sample: parse them once.
        predictors = [self.featname2tuple(name) for name in self.list_predictors]
        targets = [self.featname2tuple(name) for name in self.list_targets]

        for i, (r, s) in enumerate(list_r_s):
            run_dir = os.path.join(self.path_data, f'{r+1+self.startrun:04d}')
            fileIDX = os.path.join(run_dir, f'{s:03d}.nc')
            fileIDy = os.path.join(run_dir, f'{s+1:03d}.nc')
            # Context managers close both datasets even if a read fails.
            with netCDF4.Dataset(fileIDX, "r") as ncx, netCDF4.Dataset(fileIDy, "r") as ncy:
                for k, (varname, channel) in enumerate(predictors):
                    X[i, ..., k] = ncx.variables[varname][0, channel, :, :]
                for k, (varname, channel) in enumerate(targets):
                    y[i, ..., k] = ncy.variables[varname][0, channel, :, :]
        return (X, y)