# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional, Tuple
import torch as th
import torch.nn.functional as thf
from drtk.interpolate import interpolate
from drtk.utils import index, load_torch_ops
load_torch_ops("drtk.edge_grad_ext")
[docs]
@th.compiler.disable
def edge_grad_estimator(
v_pix: th.Tensor,
vi: th.Tensor,
bary_img: th.Tensor,
img: th.Tensor,
index_img: th.Tensor,
v_pix_img_hook: Optional[Callable[[th.Tensor], None]] = None,
) -> th.Tensor:
"""Makes the rasterized image ``img`` differentiable at visibility discontinuities
and backpropagates the gradients to ``v_pix``.
This function takes a rasterized image ``img`` that is assumed to be differentiable at
continuous regions but not at discontinuities. In some cases, ``img`` may not be differentiable
at all. For example, if the image is a rendered segmentation mask, it remains constant at
continuous regions, making it non-differentiable. However, ``edge_grad_estimator`` can still
compute gradients at the discontinuities with respect to ``v_pix``.
The arguments ``bary_img`` and ``index_img`` must correspond exactly to the rasterized image
``img``. Each pixel in ``img`` should correspond to a fragment originated prom primitive
specified by ``index_img`` and it should have barycentric coordinates specified by
``bary_img``. This means that with a small change to ``v_pix``, the pixels in ``img`` should
change accordingly. A frequent mistake that violates this condition is applying a mask
to the rendered image to exclude unwanted regions, which leads to erroneous gradients.
The function returns the ``img`` unchanged but with added differentiability at the
discontinuities. Note that it is not necessary for the input ``img`` to require gradients,
but the returned ``img`` will always require gradients.
Args:
v_pix (Tensor): Pixel-space vertex coordinates, preserving the original camera-space
Z-values. Shape: :math:`(N, V, 3)`.
vi (Tensor): Face vertex index list tensor. Shape: :math:`(V, 3)`.
bary_img (Tensor): 3D barycentric coordinate image tensor. Shape: :math:`(N, 3, H, W)`.
img (Tensor): The rendered image. Shape: :math:`(N, C, H, W)`.
index_img (Tensor): Index image tensor. Shape: :math:`(N, H, W)`.
v_pix_img_hook (Optional[Callable[[th.Tensor], None]]): An optional backward hook that will
be registered to ``v_pix_img``. Useful for examining the generated image space. Default
is None.
Returns:
Tensor: Returns the input ``img`` unchanged. However, the returned image now has added
differentiability at visibility discontinuities. This returned image should be used for
computing losses
Note:
It is crucial not to spatially modify the rasterized image before passing it to
`edge_grad_estimator`. That stems from the requirement that ``bary_img`` and ``index_img``
must correspond exactly to the rasterized image ``img``. That means that the location of all
discontinuities is controlled by ``v_pix`` and can be modified by modifing ``v_pix``.
Operations that are allowed, as long as they are differentiable, include:
- Pixel-wise MLP
- Color mapping
- Color correction, gamma correction
- Anything that would be indistinguishable from processing fragments independently
before their values get assigned to pixels of ``img``
Operations that **must be avoided** before `edge_grad_estimator` include:
- Gaussian blur
- Warping or deformation
- Masking, cropping, or introducing holes
There is however, no issue with appling them after `edge_grad_estimator`.
If the operation is highly non-linear, it is recommended to perform it before calling
:func:`edge_grad_estimator`.
All sorts of clipping and clamping (e.g., `x.clamp(min=0.0, max=1.0)`) must also be done
before invoking this function.
Usage Example::
import torch.nn.functional as thf
from drtk import transform, rasterize, render, interpolate, edge_grad_estimator
...
v_pix = transform(v, tex, campos, camrot, focal, princpt)
index_img = rasterize(v_pix, vi, width=512, height=512)
_, bary_img = render(v_pix, vi, index_img)
vt_img = interpolate(vt, vti, index_img, bary_img)
img = thf.grid_sample(
tex,
vt_img.permute(0, 2, 3, 1),
mode="bilinear",
padding_mode="border",
align_corners=False
)
mask = (index_img != -1)[:, None, :, :]
img = img * mask
img = edge_grad_estimator(
v_pix=v_pix,
vi=vi,
bary_img=bary_img,
img=img,
index_img=index_img
)
optim.zero_grad()
image_loss = loss_func(img, img_gt)
image_loss.backward()
optim.step()
"""
# TODO: avoid call to interpolate, use backward kernel of interpolate directly
# Doing so will make `edge_grad_estimator` zero-overhead in forward pass
# At the moment, value of `v_pix_img` is ignored, and only passed to
# edge_grad_estimator so that backward kernel can be called with the computed gradient.
v_pix_img = interpolate(v_pix, vi, index_img, bary_img.detach())
img = th.ops.edge_grad_ext.edge_grad_estimator(v_pix, v_pix_img, vi, img, index_img)
if v_pix_img_hook is not None:
v_pix_img.register_hook(v_pix_img_hook)
return img
[docs]
def edge_grad_estimator_ref(
v_pix: th.Tensor,
vi: th.Tensor,
bary_img: th.Tensor,
img: th.Tensor,
index_img: th.Tensor,
v_pix_img_hook: Optional[Callable[[th.Tensor], None]] = None,
) -> th.Tensor:
"""
Python reference implementation for
:func:`drtk.edge_grad_estimator`.
"""
# could use v_pix_img output from DRTK, but bary_img needs to be detached.
v_pix_img = interpolate(v_pix, vi, index_img, bary_img.detach())
# pyre-fixme[16]: `EdgeGradEstimatorFunction` has no attribute `apply`.
img = EdgeGradEstimatorFunction.apply(v_pix, v_pix_img, vi, img, index_img)
if v_pix_img_hook is not None:
v_pix_img.register_hook(v_pix_img_hook)
return img
class EdgeGradEstimatorFunction(th.autograd.Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx,
v_pix: th.Tensor,
v_pix_img: th.Tensor,
vi: th.Tensor,
img: th.Tensor,
index_img: th.Tensor,
) -> th.Tensor:
ctx.save_for_backward(v_pix, img, index_img, vi)
return img
@staticmethod
# pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently.
def backward(ctx, grad_output: th.Tensor) -> Tuple[
Optional[th.Tensor],
Optional[th.Tensor],
Optional[th.Tensor],
Optional[th.Tensor],
Optional[th.Tensor],
]:
# early exit in case geometry is not optimized.
if not ctx.needs_input_grad[1]:
return None, None, None, grad_output, None
v_pix, img, index_img, vi = ctx.saved_tensors
x_grad = img[:, :, :, 1:] - img[:, :, :, :-1]
y_grad = img[:, :, 1:, :] - img[:, :, :-1, :]
l_index = index_img[:, None, :, :-1]
r_index = index_img[:, None, :, 1:]
t_index = index_img[:, None, :-1, :]
b_index = index_img[:, None, 1:, :]
x_mask = r_index != l_index
y_mask = b_index != t_index
x_both_triangles = (r_index != -1) & (l_index != -1)
y_both_triangles = (b_index != -1) & (t_index != -1)
iimg_clamped = index_img.clamp(min=0).long()
# compute barycentric coordinates
b = v_pix.shape[0]
vi_img = index(vi, iimg_clamped, 0).long()
p0 = th.cat(
[index(v_pix[i], vi_img[i, ..., 0].data, 0)[None, ...] for i in range(b)],
dim=0,
)
p1 = th.cat(
[index(v_pix[i], vi_img[i, ..., 1].data, 0)[None, ...] for i in range(b)],
dim=0,
)
p2 = th.cat(
[index(v_pix[i], vi_img[i, ..., 2].data, 0)[None, ...] for i in range(b)],
dim=0,
)
v10 = p1 - p0
v02 = p0 - p2
n = th.cross(v02, v10)
px, py = th.meshgrid(
th.arange(img.shape[-2], device=v_pix.device),
th.arange(img.shape[-1], device=v_pix.device),
)
def epsclamp(x: th.Tensor) -> th.Tensor:
return th.where(x < 0, x.clamp(max=-1e-8), x.clamp(min=1e-8))
# pyre-fixme[53]: Captured variable `n` is not annotated.
# pyre-fixme[53]: Captured variable `p0` is not annotated.
# pyre-fixme[53]: Captured variable `px` is not annotated.
# pyre-fixme[53]: Captured variable `py` is not annotated.
# pyre-fixme[53]: Captured variable `v02` is not annotated.
# pyre-fixme[53]: Captured variable `v10` is not annotated.
def check_if_point_inside_triangle(offset_x: int, offset_y: int) -> th.Tensor:
_px = px + offset_x
_py = py + offset_y
vp0p = th.stack([p0[..., 0] - _px, p0[..., 1] - _py], dim=-1) / epsclamp(
n[..., 2:3]
)
bary_1 = v02[..., 0] * -vp0p[..., 1] + v02[..., 1] * vp0p[..., 0]
bary_2 = v10[..., 0] * -vp0p[..., 1] + v10[..., 1] * vp0p[..., 0]
return ((bary_1 > 0) & (bary_2 > 0) & ((bary_1 + bary_2) < 1))[:, None]
left_pnt_inside_right_triangle = (
check_if_point_inside_triangle(-1, 0)[..., :, 1:]
& x_mask
& x_both_triangles
)
right_pnt_inside_left_triangle = (
check_if_point_inside_triangle(1, 0)[..., :, :-1]
& x_mask
& x_both_triangles
)
down_pnt_inside_up_triangle = (
check_if_point_inside_triangle(0, 1)[..., :-1, :]
& y_mask
& y_both_triangles
)
up_pnt_inside_down_triangle = (
check_if_point_inside_triangle(0, -1)[..., 1:, :]
& y_mask
& y_both_triangles
)
horizontal_intersection = (
right_pnt_inside_left_triangle & left_pnt_inside_right_triangle
)
vertical_intersection = (
down_pnt_inside_up_triangle & up_pnt_inside_down_triangle
)
left_hangs_over_right = left_pnt_inside_right_triangle & (
~right_pnt_inside_left_triangle
)
right_hangs_over_left = right_pnt_inside_left_triangle & (
~left_pnt_inside_right_triangle
)
up_hangs_over_down = up_pnt_inside_down_triangle & (
~down_pnt_inside_up_triangle
)
down_hangs_over_up = down_pnt_inside_up_triangle & (
~up_pnt_inside_down_triangle
)
x_grad *= x_mask
y_grad *= y_mask
grad_output_x = 0.5 * (grad_output[:, :, :, 1:] + grad_output[:, :, :, :-1])
grad_output_y = 0.5 * (grad_output[:, :, 1:, :] + grad_output[:, :, :-1, :])
x_grad = (x_grad * grad_output_x).sum(dim=1)
y_grad = (y_grad * grad_output_y).sum(dim=1)
x_grad_no_int = x_grad * (~horizontal_intersection[:, 0])
y_grad_no_int = y_grad * (~vertical_intersection[:, 0])
x_grad_spread = th.zeros(
*x_grad_no_int.shape[:1],
x_grad_no_int.shape[1],
y_grad_no_int.shape[2],
dtype=x_grad_no_int.dtype,
device=x_grad_no_int.device,
)
x_grad_spread[:, :, :-1] = x_grad_no_int * (~right_hangs_over_left[:, 0])
x_grad_spread[:, :, 1:] += x_grad_no_int * (~left_hangs_over_right[:, 0])
y_grad_spread = th.zeros_like(x_grad_spread)
y_grad_spread[:, :-1, :] = y_grad_no_int * (~down_hangs_over_up[:, 0])
y_grad_spread[:, 1:, :] += y_grad_no_int * (~up_hangs_over_down[:, 0])
# Intersections. Compute border sliding gradients
#################################################
z_grad_spread = th.zeros_like(x_grad_spread)
x_grad_int = x_grad * horizontal_intersection[:, 0]
y_grad_int = y_grad * vertical_intersection[:, 0]
n = thf.normalize(n, dim=-1)
n = n.permute(0, 3, 1, 2)
n_left = n[..., :, :-1]
n_right = n[..., :, 1:]
n_up = n[..., :-1, :]
n_down = n[..., 1:, :]
def get_dp_db(v_varying: th.Tensor, v_fixed: th.Tensor) -> th.Tensor:
"""
Computes derivative of the point position with respect to edge displacement
See drtk/src/edge_grad/edge_grad_kernel.cu
Please refer to the paper "Rasterized Edge Gradients: Handling Discontinuities Differentiably"
for details.
"""
v_varying = thf.normalize(v_varying, dim=1)
v_fixed = thf.normalize(v_fixed, dim=1)
b = th.stack([-v_fixed[:, 1], v_fixed[:, 0]], dim=1)
b_dot_varying = (b * v_varying).sum(dim=1, keepdim=True)
return b[:, 0:1] / epsclamp(b_dot_varying) * v_varying
# We compute partial derivatives by fixing one triangle and moving the
# other, and then vice versa.
# Left triangle moves, right fixed
dp_dbx = get_dp_db(n_left[:, [0, 2]], -n_right[:, [0, 2]])
x_grad_spread[:, :, :-1] += x_grad_int * dp_dbx[:, 0]
z_grad_spread[:, :, :-1] += x_grad_int * dp_dbx[:, 1]
# Left triangle fixed, right moves
dp_dbx = get_dp_db(n_right[:, [0, 2]], n_left[:, [0, 2]])
x_grad_spread[:, :, 1:] += x_grad_int * dp_dbx[:, 0]
z_grad_spread[:, :, 1:] += x_grad_int * dp_dbx[:, 1]
# Upper triangle moves, lower fixed
dp_dby = get_dp_db(n_up[:, [1, 2]], -n_down[:, [1, 2]])
y_grad_spread[:, :-1, :] += y_grad_int * dp_dby[:, 0]
z_grad_spread[:, :-1, :] += y_grad_int * dp_dby[:, 1]
# Lower triangle moves, upper fixed
dp_dby = get_dp_db(n_down[:, [1, 2]], n_up[:, [1, 2]])
y_grad_spread[:, 1:, :] += y_grad_int * dp_dby[:, 0]
z_grad_spread[:, 1:, :] += y_grad_int * dp_dby[:, 1]
m = index_img == -1
x_grad_spread[m] = 0.0
y_grad_spread[m] = 0.0
grad_v_pix = -th.stack([x_grad_spread, y_grad_spread, z_grad_spread], dim=3)
return None, grad_v_pix, None, grad_output, None