Improving Physics-Informed Neural Networks through Adaptive Loss Balancing

In this article, we review the basics of PINNs, explore the issue of imbalanced losses, and show how the balancing scheme ReLoBRaLo (Relative Loss Balancing with Random Lookbacks) [1], proposed by Michael Kraus and myself, can significantly boost the training process. Plus, experience the technique in action with two accompanying notebooks solving real-world PDEs:
If you clicked on this article, it is probably because you already have quite a good understanding of what Physics-Informed Neural Networks (PINNs) [2] are. Maybe you have found some tutorials online and implemented PINNs on well-known benchmarks like the Burgers or Helmholtz equation. The idea of harnessing the power of Neural Networks to solve complex partial differential equations (PDEs) is certainly an appealing one. But as many of us have painfully discovered, working with PINNs can also be quite a frustrating process. If you went ahead and tried applying these tools to a PDE that you encountered in your own research, one that may not yet be well-documented in the literature, then it is very likely that vanilla PINNs did not perform as well as you had hoped. Worse still, they may even have converged more slowly than established approaches like the Finite Element Method (FEM)!
Bear with me: I have been there, basically every time I tried applying PINNs to a new problem. Despite the progress that has been made in the five years since their proposal, and the decades of research on using differentiable Neural Networks to solve differential equations [3], there is still no easy, plug-and-play version of PINNs that can be seamlessly transferred to any type of problem.
You see, PINNs make use of differential equations in their loss function by taking multiple higher-order derivatives of the output with respect to the input. These derivatives are then used to construct the residual and boundary conditions that should be approximated. This means that each partial differential equation fundamentally changes the PINN's training procedure. Not only may it become necessary to adapt the architecture, such as adding or removing layers and nodes, but other, more intricate hyperparameters may have a crucial impact on the modelling capabilities, many of which are peculiarities of PINNs and cannot be found in the literature of classical Neural Networks. These include the choice of activation functions, the sampling procedure on the physical domain, or, most treacherous of all, the choice of units of measurement in the differential equation.
While I can unfortunately not give you all the ingredients necessary to make your PINNs work, I can most certainly tell you what the fundamental steps are, without which your endeavour will most likely be fruitless. But before I reveal these crucial tools, allow me to provide some context to better understand my arguments. Let us take a step back and open a parenthesis, if you will. (
Benchmark PDEs
For the sake of illustration, let us introduce the Helmholtz and Kirchhoff plate bending equations. But before you start feeling overwhelmed, let me assure you that understanding the intricacies of these PDEs is not necessary for following the rest of this article. If you want to skip this section, just know that the Helmholtz PDE is a second-order PDE with zeroth-order (Dirichlet) boundary conditions, and the Kirchhoff plate bending equation is a fourth-order PDE with boundary conditions on the zeroth- and second-order derivatives.
The Kirchhoff Plate Bending Equation: Modeling Plate Deformation Under Load
The Kirchhoff plate bending equation is a fourth-order partial differential equation (PDE) that describes the deformation of a plate under load. The unknown function u in the equation represents the vertical displacement (in meters, for example) from the initial state of the plate at a given point (x, y). The load applied on the plate is represented by the function p(x, y). The constant D in the equation encapsulates various properties of the plate such as its thickness, modulus of elasticity, and density.
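In its standard biharmonic form (my reconstruction, consistent with the description above), the equation reads:
$$\frac{\partial^4 u}{\partial x^4} + 2\,\frac{\partial^4 u}{\partial x^2 \partial y^2} + \frac{\partial^4 u}{\partial y^4} = \frac{p(x, y)}{D}$$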
So, Kirchhoff's equation states that the fourth-order derivative of the deformation is equal to the load applied on the plate divided by a constant factor. Fairly straightforward, right?
Of course, every experienced PDE-tian knows that a governing equation, as elegant as it may seem, is nothing but a meaningless abstraction without the proper boundary conditions. After all, there are infinitely many functions that could fulfill it.
So let us also introduce the boundary conditions:
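(written here for a simply supported plate spanning 0 ≤ x ≤ W and 0 ≤ y ≤ H; the exact parametrisation of the domain is my assumption, chosen to match the description that follows)
$$u(0, y) = u(W, y) = u(x, 0) = u(x, H) = 0$$
$$\left.\frac{\partial^2 u}{\partial x^2}\right|_{x=0} = \left.\frac{\partial^2 u}{\partial x^2}\right|_{x=W} = \left.\frac{\partial^2 u}{\partial y^2}\right|_{y=0} = \left.\frac{\partial^2 u}{\partial y^2}\right|_{y=H} = 0$$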
where W and H denote the plate's width and height, respectively. The first row of the boundary conditions shows the zeroth-order (Dirichlet) conditions and states that the edges of the plate are not allowed to deflect. The second row shows the second-order derivatives, which enforce the moments on the edges to be zero. This can be illustrated by an edge of a plate that is supported from below by a beam (hence the zero zeroth-order derivative) and squeezed by another beam from above (resulting in zero moments).

The Helmholtz Equation: Modeling Waves in a Medium
The Helmholtz equation is a partial differential equation that describes the propagation of waves in a medium. It is a second-order equation and is named after the German physicist Hermann von Helmholtz.
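In two dimensions, and with a forcing term q(x, y) on the right-hand side (the particular q used in this benchmark is the expression that appears in the loss code further below), it can be written as:
$$\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} + k^2 u(x, y) = q(x, y)$$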
where k is the wave number and u(x, y) the unknown function to be found. For this problem, we will use the zeroth-order Dirichlet boundary conditions on all four edges of the domain:
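(on the domain [-1, 1] × [-1, 1], which is what the boundary checks at ±1 in the loss code below imply)
$$u(-1, y) = u(1, y) = u(x, -1) = u(x, 1) = 0$$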

PINN Loss Function
If I lost you anywhere during the definition of the Kirchhoff or Helmholtz equations, do not worry. It took me over half a year, and countless explanations from patient civil engineers, before I was able to regurgitate these formulas to you.
The real key is understanding how to translate these equations into a loss function that can be used to train our PINN, here for the Helmholtz equation:
import numpy as np
import tensorflow as tf
from tensorflow.experimental.numpy import isclose

EPS = 1e-5

def compute_loss(self, x, y, u, dudxx, dudyy, eval=False):
    """
    Computes the Physics-informed loss for Helmholtz's PDE.

    Parameters
    ----------
    x : tf.Tensor of shape (batch_size, 1)
        x coordinate of the points in the current batch
    y : tf.Tensor of shape (batch_size, 1)
        y coordinate of the points in the current batch
    u : tf.Tensor of shape (batch_size, 1)
        predictions made by our PINN (dim 0)
    dudxx : tf.Tensor of shape (batch_size, 1)
        second-order derivative of the predictions w.r.t. x
    dudyy : tf.Tensor of shape (batch_size, 1)
        second-order derivative of the predictions w.r.t. y
    eval : bool, optional
        if True, additionally return the error w.r.t. the analytical solution
    """
    # governing equation loss
    L_f = (dudxx + dudyy + self.k**2 * u -
           (-np.pi**2 - (4 * np.pi)**2 + self.k**2) * tf.math.sin(np.pi * x) * tf.math.sin(4 * np.pi * y))**2

    # determine which points are on the boundaries of the domain
    # if a point is on either of the boundaries, its value is 1 and 0 otherwise
    x_lower = tf.cast(isclose(x, -1, rtol=0., atol=EPS), dtype=tf.float32)
    x_upper = tf.cast(isclose(x, 1, rtol=0., atol=EPS), dtype=tf.float32)
    y_lower = tf.cast(isclose(y, -1, rtol=0., atol=EPS), dtype=tf.float32)
    y_upper = tf.cast(isclose(y, 1, rtol=0., atol=EPS), dtype=tf.float32)

    # compute 0th order boundary condition loss
    L_b = ((x_lower + x_upper + y_lower + y_upper) * u)**2

    if eval:
        # error against the known analytical solution u = sin(pi*x) * sin(4*pi*y)
        L_u = (tf.math.sin(np.pi * x) * tf.math.sin(4 * np.pi * y) - u)**2
        return L_f, L_b, L_u

    return L_f, L_b
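Note that the second-order derivatives dudxx and dudyy are not predicted by the network directly; they are obtained through automatic differentiation. The following is a minimal sketch of how this could be done with nested tf.GradientTape calls (the function name compute_derivatives and the way the model is called are illustrative, not taken from the notebooks):

import tensorflow as tf

def compute_derivatives(model, x, y):
    # x, y: tf.Tensors of shape (batch_size, 1); model maps (x, y) -> u
    with tf.GradientTape(persistent=True) as outer:
        outer.watch([x, y])
        with tf.GradientTape(persistent=True) as inner:
            inner.watch([x, y])
            u = model(tf.concat([x, y], axis=-1))
        # first-order derivatives
        dudx = inner.gradient(u, x)
        dudy = inner.gradient(u, y)
    # second-order derivatives, as expected by compute_loss
    dudxx = outer.gradient(dudx, x)
    dudyy = outer.gradient(dudy, y)
    return u, dudxx, dudyy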
You can find the full code in the notebooks implementing ReLoBRaLo for the Helmholtz and the Kirchhoff PDEs.
Multi-Objective Optimisation
As we have already established, the final loss function for our Helmholtz PDE will consist of two objectives, while the one for the Kirchhoff PDE consists of three:
- Helmholtz: the loss for the governing equation L_f and the loss for the 0th order boundary condition L_b0.
- Kirchhoff: in addition to L_f and L_b0, Kirchhoff also has a term for the second-order boundary condition L_b2.
Therefore, these losses fall into the category of Multi-Objective Optimisation (MOO), as is the case for most applications involving PINNs.
The several objectives are usually aggregated into a single loss through linear scalarisation:
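(a weighted sum over the n individual terms)
$$L = \sum_{i=1}^{n} \lambda_i L_i$$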
where the lambdas are scaling factors for controlling each term's contribution towards the total loss. But why are they necessary?
The Issue of Imbalanced Gradients
After this detour for gathering the necessary context, we can finally close the open parenthesis ) and continue exploring why the units of measurement in the PDE have an influence on the convergence of PINNs. You see, the several objectives in our loss function – L_f, L_b0, and L_b2 – each have different units of measurement. L_b0 for Kirchhoff may be measured in meters, while L_b2 is measured in Nm, and the load on the plate is measured in MN per square meter. This creates a significant disparity in the magnitude of each term, leading to gradients that heavily favour the terms with the largest magnitude. The same is true for Helmholtz and any other PDE.
Let us have a look at what this means in our example with the Helmholtz equation.

Notice how the governing equation loss L_f is several orders of magnitude larger than the losses for the boundary conditions at the beginning of training and, as a consequence, how the value of L_b starts off by actually INCREASING. This discrepancy in magnitude can lead to a PINN that prioritizes L_f over L_b, ultimately converging towards a solution that satisfies the governing equation but neglects the crucial boundary conditions. This effect can be observed in the plot by the fact that the validation loss L_u follows the same pattern as the boundary loss L_b, suggesting that the validation performance is closely related to the performance on the boundaries.
What about the Kirchhoff PDE?

In the case of Kirchhoff, the opposite holds true. Here, the boundary conditions converge much more rapidly, while the governing equation makes little progress. The most likely explanation is that the governing equation involves fourth-order derivatives and is therefore a particularly hard objective to optimise for. This shows that the causes of imbalanced losses are not limited to differences in magnitude between the terms. They range from the choice of activation function to the complexity of the function being approximated by each term.
Imbalanced gradients are by no means limited to the Helmholtz or Kirchhoff PDEs alone. Many studies have documented this issue in various PINN applications [4]. The key takeaway here is that, in order to arrive at accurate solutions, it is essential to strike a balance between all the objectives in the loss function.
Adaptive Loss Balancing Schemes
To mitigate the issue of imbalanced losses and gradients, one can resort to the scaling factors lambda in the linear scalarisation of the Multi-Objective Optimisation introduced earlier. Selecting larger values of lambda for terms with smaller magnitudes or harder objectives can help even out the contributions to the final gradient and thus make sure that all terms are appropriately approximated. However, doing this by hand is a tedious task, requiring many iterations and thus a lot of resources in terms of time and compute.
This is why researchers have proposed loss balancing schemes, such as GradNorm [5], SoftAdapt [6] or Learning Rate Annealing [4].
Relative Loss Balancing with Random Lookbacks (ReLoBRaLo)
In this article, we will focus on a scheme called Relative Loss Balancing with Random Lookbacks (ReLoBRaLo) [1], which is a combination of the aforementioned methods.
The goal of ReLoBRaLo is to ensure that each term in the loss function makes the same amount of progress over time, relative to its value at the start of training. For example, if L_f improves by 50% since the beginning of training, we want the other terms to improve at about the same rate and achieve a reduction of 50%. However, if there is a term that consistently improves at a slower rate, ReLoBRaLo incrementally increases the scaling lambda of this term, thus increasing its contribution to the gradient calculation.
Let us say that we have n loss terms L_i and let us denote the function L_i(t) to be the value of this term at training iteration t. One way that we can measure its progress since the start of training is by dividing the value at the current iteration L_i(t) by the value at the beginning of training, L_i(0):
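$$\frac{L_i(t)}{L_i(0)}$$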
The greater the progress since the beginning of training, the smaller the result of this operation. Observe how this is exactly what we are looking for: our scheme should attribute high scalings to terms that improved slowly, and small scalings to terms that improved fast – and all of that should happen independently of the absolute values of the terms. Therefore, we can use L_i(t) / L_i(0) for calculating the scalings of the terms in the loss function.
While this ratio is the key component of ReLoBRaLo, the full scheme contains a number of additional extensions that have been found to further improve performance. However, for the sake of readability of this article, I leave it to the interested reader to have a look at the paper and learn more about the methods used and their motivation.
But does it work? Well, let us have a look at the loss evolution on the Helmholtz PDE, this time using ReLoBRaLo to balance the contributions of the terms to the total loss:

While the loss for the governing equation L_f does not make as much progress anymore (it converged at around -3.8 in the previous plot), the boundary conditions L_b and in consequence the validation loss L_u receive much more weight. The final validation loss against the analytical solution yields a 65% improvement over the unscaled training run. Let us have a look at the scaling values that ReLoBRaLo computed:

The same goes for Kirchhoff:


Again, ReLoBRaLo improved the error against the analytical solution by over an order of magnitude. It is also worth noting that this balancing scheme adds almost no computational overhead (cf. the paper). It is this effectiveness and efficiency that earned ReLoBRaLo its way into Nvidia's Modulus framework for Physics-Informed Deep Learning.
But the real question is: can you use ReLoBRaLo in your own projects? The answer is a resounding yes! As it happens, the scheme can be neatly wrapped into a Keras loss that can either be added to your Keras model through model.compile(), or, in case you have defined your own custom training loop, be called explicitly at each iteration.
You can find the full code in the notebooks implementing ReLoBRaLo for the Helmholtz and the Kirchhoff PDEs.
import tensorflow as tf

EPS = 1e-5  # small constant for numerical stability

class ReLoBRaLoLoss(tf.keras.losses.Loss):
    """
    Class for the ReLoBRaLo Loss.
    This class extends the Keras Loss class to have dynamic weighting for each term.
    """
    def __init__(self, pde: HelmholtzPDE, alpha: float = 0.999, temperature: float = 0.1, rho: float = 0.99,
                 name='ReLoBRaLoLoss', **kwargs):
        """
        Parameters
        ----------
        pde : PDE
            An instance of a PDE class containing the PDE-specific `compute_loss` function.
        alpha, optional : float
            Controls the exponential weight decay rate.
            Value between 0 and 1. The smaller, the more stochasticity.
            0 means no historical information is transmitted to the next iteration.
            1 means only the first calculation is retained. Defaults to 0.999.
        temperature, optional : float
            Softmax temperature coefficient. Controls the "sharpness" of the softmax operation.
            Defaults to 0.1.
        rho, optional : float
            Probability of the Bernoulli random variable controlling the frequency of random lookbacks.
            Value between 0 and 1. The smaller, the fewer lookbacks happen.
            0 means lambdas are always calculated w.r.t. the initial loss values.
            1 means lambdas are always calculated w.r.t. the loss values in the previous training iteration.
            Defaults to 0.99.
        """
        super().__init__(name=name, **kwargs)
        self.pde = pde
        self.alpha = alpha
        self.temperature = temperature
        self.rho = rho
        self.call_count = tf.Variable(0, trainable=False, dtype=tf.int16)

        self.lambdas = [tf.Variable(1., trainable=False) for _ in range(pde.num_terms)]
        self.last_losses = [tf.Variable(1., trainable=False) for _ in range(pde.num_terms)]
        self.init_losses = [tf.Variable(1., trainable=False) for _ in range(pde.num_terms)]

    def call(self, xy, preds):
        x, y = xy[:, :1], xy[:, 1:]

        # obtain the unscaled values for each term
        losses = [tf.reduce_mean(loss) for loss in self.pde.compute_loss(x, y, preds)]

        # in the first iteration (self.call_count == 0), drop lambda_hat and use init lambdas, i.e. lambda = 1
        #   i.e. alpha = 1 and rho = 1
        # in the second iteration (self.call_count == 1), drop init lambdas and use only lambda_hat
        #   i.e. alpha = 0 and rho = 1
        # afterwards, default procedure (see paper)
        #   i.e. alpha = self.alpha and rho = Bernoulli random variable with p = self.rho
        alpha = tf.cond(tf.equal(self.call_count, 0),
                        lambda: 1.,
                        lambda: tf.cond(tf.equal(self.call_count, 1),
                                        lambda: 0.,
                                        lambda: self.alpha))
        rho = tf.cond(tf.equal(self.call_count, 0),
                      lambda: 1.,
                      lambda: tf.cond(tf.equal(self.call_count, 1),
                                      lambda: 1.,
                                      lambda: tf.cast(tf.random.uniform(shape=()) < self.rho, dtype=tf.float32)))

        # compute new lambdas w.r.t. the losses in the previous iteration
        lambdas_hat = [losses[i] / (self.last_losses[i] * self.temperature + EPS) for i in range(len(losses))]
        lambdas_hat = tf.nn.softmax(lambdas_hat - tf.reduce_max(lambdas_hat)) * tf.cast(len(losses), dtype=tf.float32)

        # compute new lambdas w.r.t. the losses in the first iteration
        init_lambdas_hat = [losses[i] / (self.init_losses[i] * self.temperature + EPS) for i in range(len(losses))]
        init_lambdas_hat = tf.nn.softmax(init_lambdas_hat - tf.reduce_max(init_lambdas_hat)) * tf.cast(len(losses), dtype=tf.float32)

        # use rho for deciding whether a random lookback should be performed
        new_lambdas = [(rho * alpha * self.lambdas[i] + (1 - rho) * alpha * init_lambdas_hat[i] + (1 - alpha) * lambdas_hat[i]) for i in range(len(losses))]
        self.lambdas = [var.assign(tf.stop_gradient(lam)) for var, lam in zip(self.lambdas, new_lambdas)]

        # compute weighted loss
        loss = tf.reduce_sum([lam * loss for lam, loss in zip(self.lambdas, losses)])

        # store current losses in self.last_losses to be accessed in the next iteration
        self.last_losses = [var.assign(tf.stop_gradient(loss)) for var, loss in zip(self.last_losses, losses)]

        # in the first iteration, store losses in self.init_losses to be accessed in next iterations
        first_iteration = tf.cast(self.call_count < 1, dtype=tf.float32)
        self.init_losses = [var.assign(tf.stop_gradient(loss * first_iteration + var * (1 - first_iteration))) for var, loss in zip(self.init_losses, losses)]

        self.call_count.assign_add(1)

        return loss
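To close the loop, here is a minimal, hedged sketch of how the loss might be called inside a custom training loop. The names model, dataset, pde and predict_with_derivatives are illustrative placeholders (in particular, predict_with_derivatives stands in for whatever notebook routine bundles the predictions and their derivatives into the format expected by pde.compute_loss):

import tensorflow as tf

loss_fn = ReLoBRaLoLoss(pde=pde, alpha=0.999, temperature=0.1, rho=0.99)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for xy in dataset:  # batches of (x, y) collocation points with shape (batch_size, 2)
    with tf.GradientTape() as tape:
        # hypothetical helper returning u (and its derivatives) for the current batch
        preds = predict_with_derivatives(model, xy)
        # balanced scalar loss; the lambdas are updated internally at every call
        loss = loss_fn(xy, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))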
Thank you for reading to the end of this article! If you found it helpful and would like to use ReLoBRaLo or the notebooks in your own work, please use this citation. You can find more information about me on rabischof.ch and about my colleague on mkrausai.com.
[1] Rafael Bischof and Michael Kraus. Multi-objective loss balancing for physics-informed deep learning. arXiv preprint arXiv:2110.09813, 2021.
[2] M. Raissi, P. Perdikaris, and G. E. Karniadakis, Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations, Journal of Computational Physics 378 (2019), 686–707.
[3] H. Lee and I. S. Kang, Neural algorithm for solving differential equations, J. Comput. Phys. 91 (1990), no. 1, 110–131
[4] Wang, S., Teng, Y., and Perdikaris, P. Understanding and mitigating gradient pathologies in physics-informed neural networks. arXiv e-prints (Jan. 2020), arXiv:2001.04536.
[5] Chen, Z., Badrinarayanan, V., Lee, C.-Y., and Rabinovich, A. GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks. arXiv e-prints (Nov. 2017), arXiv:1711.02257.
[6] Heydari, A. A., Thompson, C. A., and Mehmood, A. SoftAdapt: Techniques for Adaptive Loss Weighting of Neural Networks with Multi-Part Loss Functions. arXiv e-prints (Dec. 2019), arXiv:1912.12355.