import torch
import numpy as np
import gradoptics as optics
import matplotlib.pyplot as plt
Matplotlib created a temporary config/cache directory at /var/folders/tg/2_q32n3x5q75j4ytd6n3kmvh0000gp/T/matplotlib-kr6wlc7_ because the default path (/Users/stanford/.matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.

Creating a custom optical element

Creating an interface between two media

For now, we work in the world coordinate system and assume that the interface is parallel to the yz plane, i.e. perpendicular to the x axis.

from gradoptics.optics.base_optics import BaseOptics

class MediumInterface(BaseOptics):
    """
    Planar boundary separating two media that have different indices of refraction.

    Only the constructor is defined at this stage of the tutorial.
    """

    def __init__(self, xpos, n_ext, n_medium):
        """
        :param xpos: Location of the interface along the optical axis (:obj:`float`)
        :param n_ext: Refractive index outside the medium (:obj:`float`)
        :param n_medium: Refractive index of the medium itself (:obj:`float`)
        """
        self.xpos, self.n_ext, self.n_medium = xpos, n_ext, n_medium

Every optical element inherits from the abstract class `BaseOptics` and should implement the methods `get_ray_intersection`, `intersect`, and `plot`.

def get_ray_intersection(self, incident_rays):
    """
    Times at which the incident rays reach the plane ``x = self.xpos``.

    :param incident_rays: Rays exposing 2-D ``origins`` and ``directions`` tensors whose first column is x
    :return: One intersection time per ray (:obj:`torch.tensor`); rays with a zero x-direction give a
             non-finite value
    """
    origin_x = incident_rays.origins[:, 0]
    direction_x = incident_rays.directions[:, 0]
    return (self.xpos - origin_x) / direction_x
def intersect(self, incident_rays, t):
    """
    Refracts the incident rays at the interface ``x = self.xpos`` using the vector form of Snell's law.

    :param incident_rays: Rays hitting the interface (:py:class:`~gradoptics.optics.ray.Rays`)
    :param t: Times at which the rays hit the interface, as returned by ``get_ray_intersection``
              (:obj:`torch.tensor`)
    :return: The refracted rays (:py:class:`~gradoptics.optics.ray.Rays`) and the boolean mask selecting,
             among the incident rays, those that were refracted (rays undergoing total internal
             reflection are dropped)
    """
    directions = incident_rays.directions

    # Normal of the interface, oriented along the propagation direction of each ray (+x for the rays
    # travelling towards +x, -x for the others), as expected by the formula below
    condition = directions[:, 0] > 0
    normal = torch.zeros(directions.shape, device=directions.device, dtype=directions.dtype)
    normal[:, 0] = 1
    normal[~condition] *= -1

    # Ratio n_incident / n_transmitted: rays travelling towards +x come from the medium (x < xpos),
    # the others come from the outside (x > xpos). Same dtype as the rays to avoid a float32 round-off.
    mu = torch.zeros(directions.shape[0], device=directions.device, dtype=directions.dtype)
    mu[condition] = self.n_medium / self.n_ext
    mu[~condition] = self.n_ext / self.n_medium

    # See https://physics.stackexchange.com/questions/435512/snells-law-in-vector-form
    n_dot_d = optics.optics.vector.dot_product(normal, directions)
    tmp = 1 - mu ** 2 * (1 - n_dot_d ** 2)
    mask = tmp >= 0  # Killing the rays for which there is total internal reflection

    # Every per-ray quantity must be masked: mixing masked and unmasked tensors fails with a shape
    # mismatch (or silently broadcasts) as soon as a ray undergoes total internal reflection
    normal, directions, mu, n_dot_d = normal[mask], directions[mask], mu[mask], n_dot_d[mask]
    direction_refracted_rays = (torch.sqrt(tmp[mask]).unsqueeze(1) * normal
                                + mu.unsqueeze(1) * (directions - n_dot_d.unsqueeze(1) * normal))

    luminosities = incident_rays.luminosities[mask] if incident_rays.luminosities is not None else None
    refracted_rays = optics.Rays(incident_rays(t)[mask], direction_refracted_rays,
                                 luminosities=luminosities, device=incident_rays.device)
    return refracted_rays, mask
def plot(self, ax):
    """
    Draws the interface on 3d axes as the patch ``x = self.xpos`` with ``-0.5 <= y, z < 0.5``.

    :param ax: Matplotlib 3d axes
    """
    grid = np.arange(-.5, .5, 0.01)
    Y, Z = np.meshgrid(grid, grid)
    X = np.zeros_like(Y) + self.xpos
    ax.plot_surface(X, Y, Z)

Putting it all together.

from gradoptics.optics.base_optics import BaseOptics

class MediumInterface(BaseOptics):
    """
    Flat surface that defines an interface between two media with different indices of refraction.

    The interface is the plane ``x = xpos`` in world coordinates (parallel to the yz plane). The region
    ``x < xpos`` has index ``n_medium`` and the region ``x > xpos`` has index ``n_ext``.
    """

    def __init__(self, xpos, n_ext, n_medium):
        """
        :param xpos: Position of the interface along the optical axis (x axis) (:obj:`float`)
        :param n_ext: Index of refraction of the outside, i.e. the ``x > xpos`` side (:obj:`float`)
        :param n_medium: Index of refraction of the medium, i.e. the ``x < xpos`` side (:obj:`float`)
        """
        self.xpos = xpos
        self.n_ext = n_ext
        self.n_medium = n_medium

    def get_ray_intersection(self, incident_rays):
        """
        Computes the times at which the incident rays hit the plane ``x = xpos``.

        :param incident_rays: Incident rays (:py:class:`~gradoptics.optics.ray.Rays`)
        :return: One time per ray (:obj:`torch.tensor`); rays with a zero x-direction give a non-finite value
        """
        return (self.xpos - incident_rays.origins[:, 0]) / incident_rays.directions[:, 0]

    def intersect(self, incident_rays, t):
        """
        Refracts the incident rays at the interface using the vector form of Snell's law.

        :param incident_rays: Rays hitting the interface (:py:class:`~gradoptics.optics.ray.Rays`)
        :param t: Times at which the rays hit the interface, as returned by :meth:`get_ray_intersection`
        :return: The refracted rays (:py:class:`~gradoptics.optics.ray.Rays`) and the boolean mask selecting,
                 among the incident rays, those that were refracted (rays undergoing total internal
                 reflection are dropped)
        """
        directions = incident_rays.directions

        # Normal of the interface, oriented along the propagation direction of each ray (+x for the rays
        # travelling towards +x, -x for the others), as expected by the formula below
        condition = directions[:, 0] > 0
        normal = torch.zeros(directions.shape, device=directions.device, dtype=directions.dtype)
        normal[:, 0] = 1
        normal[~condition] *= -1

        # Ratio n_incident / n_transmitted: rays travelling towards +x come from the medium (x < xpos),
        # the others come from the outside (x > xpos). Same dtype as the rays to avoid a float32 round-off.
        mu = torch.zeros(directions.shape[0], device=directions.device, dtype=directions.dtype)
        mu[condition] = self.n_medium / self.n_ext
        mu[~condition] = self.n_ext / self.n_medium

        # See https://physics.stackexchange.com/questions/435512/snells-law-in-vector-form
        n_dot_d = optics.optics.vector.dot_product(normal, directions)
        tmp = 1 - mu ** 2 * (1 - n_dot_d ** 2)
        mask = tmp >= 0  # Killing the rays for which there is total internal reflection

        # Every per-ray quantity must be masked: mixing masked and unmasked tensors fails with a shape
        # mismatch (or silently broadcasts) as soon as a ray undergoes total internal reflection
        normal, directions, mu, n_dot_d = normal[mask], directions[mask], mu[mask], n_dot_d[mask]
        direction_refracted_rays = (torch.sqrt(tmp[mask]).unsqueeze(1) * normal
                                    + mu.unsqueeze(1) * (directions - n_dot_d.unsqueeze(1) * normal))

        luminosities = incident_rays.luminosities[mask] if incident_rays.luminosities is not None else None
        refracted_rays = optics.Rays(incident_rays(t)[mask], direction_refracted_rays,
                                     luminosities=luminosities, device=incident_rays.device)
        return refracted_rays, mask

    def plot(self, ax):
        """
        Draws the interface as the patch ``x = xpos`` with ``-0.5 <= y, z < 0.5``.

        :param ax: Matplotlib 3d axes
        """
        Y = np.arange(-.5, .5, 0.01)
        Z = np.arange(-.5, .5, 0.01)
        Y, Z = np.meshgrid(Y, Z)
        X = np.zeros_like(Y) + self.xpos

        # Plot the surface.
        ax.plot_surface(X, Y, Z)

Creating a scene with a MediumInterface.

interface = MediumInterface(.2, 1., 1.005)

# Creating a scene
f = 0.05
m = 0.15
lens = optics.PerfectLens(f=f, na=1 / 1.4, position=[0., 0., 0.], m=m)
sensor = optics.Sensor(position=(-f * (1 + m), 0, 0))
atom_cloud = optics.AtomCloud(n=int(1e6), f=2, position=[f * (1 + m) / m, 0., 0.], phi=0.1)
light_source = optics.LightSourceFromDistribution(atom_cloud)
scene = optics.Scene(light_source)
scene.add_object(lens)
scene.add_object(interface)
scene.add_object(sensor)

# Visualizing the scene
fig = plt.figure(figsize=(12, 12))
# Figure.gca() no longer accepts keyword arguments (deprecated in Matplotlib 3.4, removed in 3.6)
ax = fig.add_subplot(projection='3d')
scene.plot(ax)
../_images/a0e898ce0fba97121e9207a7dbef3bfd8b134e4c32af001e03c841f17be11a5e.png

Producing different images with different indices of refraction.

# Parameters shared by all the scenes
f = 0.05
m = 0.15
device = 'cpu'
center = (4800, 4800)  # Sensor pixel around which the produced image is cropped
w = 60  # Half-width of the crop, in pixels

for n_medium in [1., 1.005, 1.01]:

    medium_interface = MediumInterface(.2, 1., n_medium)

    # Creating a fresh scene (and hence a fresh sensor) for each index of refraction
    lens = optics.PerfectLens(f=f, na=1 / 1.4, position=[0., 0., 0.], m=m)
    sensor = optics.Sensor(position=(-f * (1 + m), 0, 0))
    atom_cloud = optics.AtomCloud(n=int(1e6), f=2, position=[f * (1 + m) / m, 0., 0.], phi=0.1)
    light_source = optics.LightSourceFromDistribution(atom_cloud)
    scene = optics.Scene(light_source)
    scene.add_object(lens)
    scene.add_object(medium_interface)
    scene.add_object(sensor)

    # Producing an image
    for _ in range(2):
        rays = light_source.sample_rays(10_000_000, device=device)

        # /!\ Setting max iteration to 3 because the rays will intersect 3 optical elements (interface -> lens -> sensor)
        optics.forward_ray_tracing(rays, scene, max_iterations=3)

    # Readout the sensor
    produced_image = sensor.readout(add_poisson_noise=False).data.cpu().numpy()
    plt.imshow(produced_image[center[0] - w: center[0] + w, center[1] - w: center[1] + w], cmap='Blues')
    plt.title(f'n_medium = {n_medium}')
    plt.show()
/Users/stanford/Library/Python/3.8/lib/python/site-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
../_images/eca6e2bc5d62b59e03e0797f248aaa3efb659bdfb21e0a8f38881b96566016a1.png ../_images/72edcd7c8f562b9fb4eabe878f94dd5415548216b5b18f715eda974f16114653.png ../_images/106bbc00b02600c819ff0098c157cb37a81c415077aa632f77cce0eceec30276.png

Adding a transform

With only minor modifications to the code above, the interface can easily be shifted and oriented.

To do so, a transform is passed to the constructor and used in the class methods to switch back and forth between world space and object space.

from gradoptics.optics.base_optics import BaseOptics

class MediumInterface(BaseOptics):
    """
    Flat surface that defines an interface between two media with different indices of refraction.

    In object space the interface is the plane ``x = 0``. The region ``x < 0`` has index ``n_medium`` and the
    region ``x > 0`` has index ``n_ext``. ``transform`` maps object space to world space, which makes it
    possible to shift and orient the interface.
    """

    def __init__(self, n_ext, n_medium, transform):
        """
        :param n_ext: Index of refraction of the outside (:obj:`float`)
        :param n_medium: Index of refraction of the medium (:obj:`float`)
        :param transform: Transform to orient the interface (:py:class:`~gradoptics.transforms.base_transform.BaseTransform`)
        """
        self.n_ext = n_ext
        self.n_medium = n_medium
        self.transform = transform

    def get_ray_intersection(self, incident_rays):
        """
        Computes the times at which the incident rays hit the interface.

        NOTE(review): t is computed in object space; it matches world-space times only if the transform
        does not rescale directions (e.g. a rigid transform) — confirm for custom transforms.

        :param incident_rays: Incident rays in world space (:py:class:`~gradoptics.optics.ray.Rays`)
        :return: One time per ray (:obj:`torch.tensor`); rays parallel to the interface give a non-finite value
        """
        incident_rays = self.transform.apply_inverse_transform(incident_rays)  # World space to object space

        # Now working with a local coordinate system at the center of the plane interface (plane x = 0)
        return - incident_rays.origins[:, 0] / incident_rays.directions[:, 0]

    def intersect(self, incident_rays, t):
        """
        Refracts the incident rays at the interface using the vector form of Snell's law.

        :param incident_rays: Rays in world space hitting the interface (:py:class:`~gradoptics.optics.ray.Rays`)
        :param t: Times at which the rays hit the interface, as returned by :meth:`get_ray_intersection`
        :return: The refracted rays in world space (:py:class:`~gradoptics.optics.ray.Rays`) and the boolean mask
                 selecting, among the incident rays, those that were refracted (rays undergoing total internal
                 reflection are dropped)
        """
        incident_rays = self.transform.apply_inverse_transform(incident_rays)  # World space to object space

        directions = incident_rays.directions

        # Normal of the interface, oriented along the propagation direction of each ray (+x for the rays
        # travelling towards +x, -x for the others), as expected by the formula below
        condition = directions[:, 0] > 0
        normal = torch.zeros(directions.shape, device=directions.device, dtype=directions.dtype)
        normal[:, 0] = 1
        normal[~condition] *= -1

        # Ratio n_incident / n_transmitted: rays travelling towards +x come from the medium (x < 0),
        # the others come from the outside (x > 0). Same dtype as the rays to avoid a float32 round-off.
        mu = torch.zeros(directions.shape[0], device=directions.device, dtype=directions.dtype)
        mu[condition] = self.n_medium / self.n_ext
        mu[~condition] = self.n_ext / self.n_medium

        # See https://physics.stackexchange.com/questions/435512/snells-law-in-vector-form
        n_dot_d = optics.optics.vector.dot_product(normal, directions)
        tmp = 1 - mu ** 2 * (1 - n_dot_d ** 2)
        mask = tmp >= 0  # Killing the rays for which there is total internal reflection

        # Every per-ray quantity must be masked: mixing masked and unmasked tensors fails with a shape
        # mismatch (or silently broadcasts) as soon as a ray undergoes total internal reflection
        normal, directions, mu, n_dot_d = normal[mask], directions[mask], mu[mask], n_dot_d[mask]
        direction_refracted_rays = (torch.sqrt(tmp[mask]).unsqueeze(1) * normal
                                    + mu.unsqueeze(1) * (directions - n_dot_d.unsqueeze(1) * normal))

        luminosities = incident_rays.luminosities[mask] if incident_rays.luminosities is not None else None
        refracted_rays = optics.Rays(incident_rays(t)[mask], direction_refracted_rays,
                                     luminosities=luminosities, device=incident_rays.device)

        return self.transform.apply_transform(refracted_rays), mask  # Object space back to world space

    def plot(self, ax):
        """
        Draws the interface: the object-space patch ``x = 0``, ``-0.5 <= y, z < 0.5``, mapped to world space.

        :param ax: Matplotlib 3d axes
        """
        Y = torch.arange(-.5, .5, 0.01)
        Z = torch.arange(-.5, .5, 0.01)
        Y, Z = torch.meshgrid(Y, Z)
        X = torch.zeros_like(Y)

        # Coordinates to world space, one (x, y, z) point per row
        xyz = self.transform.apply_transform_(torch.stack((X, Y, Z), dim=-1).reshape(-1, 3))
        # .numpy() refuses tensors that require grad; detach in case the transform is differentiable
        xyz = xyz.detach().cpu()

        # Plot the surface.
        ax.plot_surface(xyz[:, 0].reshape(X.shape).numpy(),
                        xyz[:, 1].reshape(X.shape).numpy(),
                        xyz[:, 2].reshape(X.shape).numpy())
# Float translation, consistent with the translation used for the scenes below
transform = optics.simple_transform.SimpleTransform(10, 40, 10, torch.tensor([1., 0., 0.]))
tilted_interface = MediumInterface(1., 1., transform)
fig = plt.figure(figsize=(12, 12))
# Figure.gca() no longer accepts keyword arguments (deprecated in Matplotlib 3.4, removed in 3.6)
ax = fig.add_subplot(projection='3d')
tilted_interface.plot(ax)
../_images/66c5b3f8fe690aa490f29f28d409888f0f8ee9b2805a635602624758753ac591.png
# Parameters shared by all the scenes
f = 0.05
m = 0.15
device = 'cpu'
center = (4800, 4800)  # Sensor pixel around which the produced image is cropped
w = 60  # Half-width of the crop, in pixels

for n_medium in [1., 1.005, 1.01]:

    # Interface shifted to x = 0.2 in world space, without rotation
    transform = optics.simple_transform.SimpleTransform(0, 0, 0, torch.tensor([.2, 0, 0]))
    medium_interface = MediumInterface(1., n_medium, transform)

    # Creating a fresh scene (and hence a fresh sensor) for each index of refraction
    lens = optics.PerfectLens(f=f, na=1 / 1.4, position=[0., 0., 0.], m=m)
    sensor = optics.Sensor(position=(-f * (1 + m), 0, 0))
    atom_cloud = optics.AtomCloud(n=int(1e6), f=2, position=[f * (1 + m) / m, 0., 0.], phi=0.1)
    light_source = optics.LightSourceFromDistribution(atom_cloud)
    scene = optics.Scene(light_source)
    scene.add_object(lens)
    scene.add_object(medium_interface)
    scene.add_object(sensor)

    # Producing an image (3 optical elements are intersected: interface -> lens -> sensor)
    for _ in range(2):
        rays = light_source.sample_rays(10_000_000, device=device)
        optics.forward_ray_tracing(rays, scene, max_iterations=3)

    # Readout the sensor
    produced_image = sensor.readout(add_poisson_noise=False).data.cpu().numpy()
    plt.imshow(produced_image[center[0] - w: center[0] + w, center[1] - w: center[1] + w], cmap='Blues')
    plt.title(f'n_medium = {n_medium}')
    plt.show()
../_images/a0ac0b736371854e20039625b6f7d87994d7a300c0d9b4ce8b791cc537387061.png ../_images/9c333cc0b296d289a6c20defc891eb15b41bd611f2d9da1db06f02bde2542bb7.png ../_images/0e9bf734caa499ab135008d067f934955ce7c7c681fffe5e98d12b1764f45852.png