Source code for spacepy.pybats.qotree

#!/usr/bin/env python

'''
A class for building QO trees for :class:`~spacepy.pybats.bats.Bats2d`
objects.
'''


class QTree(object):
    '''
    Base class for Quad/Oct tree objects assuming cell-centered grid points
    in a rectangular non-regular layout.

    QTree works for square blocks only (e.g., blocks who have the same
    number of points in each dimension).  The ``blocksize`` kwarg sets the
    size of the blocks.  As BATS-R-US typically uses a block size of 8 (i.e.,
    blocks are 8x8x8 points), the default value of ``blocksize`` is 8.

    Branches are stored in the dictionary ``self.tree`` and are reachable
    through ``tree[key]``.  The root has key 1; the children of branch ``k``
    occupy the consecutive keys ``leftdaughter(k)`` .. ``rightdaughter(k)``
    and ``mom`` maps any of those keys back to ``k``.
    '''

    def __init__(self, grid, blocksize=8):
        '''
        Build QO Tree for input grid.  Grid should be a NxM numpy array where
        N is the number of dimensions and M is the number of points.

        Raises NotImplementedError if N is not 2.
        '''
        from numpy import inf, lexsort

        # Set size of each block
        self.blocksize = blocksize

        (self.d, self.npoints) = grid.shape

        if self.d != 2:
            raise NotImplementedError("Sorry, QTrees are for 2D grids only.")

        self.tree = {}

        # Find limits of cell-centered grid: the extreme points are cell
        # centers, so each limit is pushed outwards by half a cell.
        xmin, xmax = self._padded_limits(grid[0, :], grid[1, :])
        ymin, ymax = self._padded_limits(grid[1, :], grid[0, :])

        # Some things all trees should know about themselves.
        self.nleafs = 0
        self.aspect_ratio = (xmax-xmin)/(ymax-ymin)
        self.dx_min = inf   # Minimum and maximum spacing
        self.dx_max = -1    # over all leafs in tree.

        # Use spatial range of grid to seed root of QTree.
        self[1] = [xmin, xmax, ymin, ymax]
        # Point indices ordered by second coordinate, then by first.
        self.locs = lexsort((grid[0, :], grid[1, :]))
        self._spawn_kids(grid)
        self.nbranch = len(self.tree)

    @staticmethod
    def _padded_limits(primary, secondary):
        '''
        Return (lower, upper) limits along the ``primary`` coordinate,
        extended past the extreme points by half the distance between each
        extreme point and its nearest distinct neighbor.

        ``secondary`` holds the other coordinate of the same points.
        '''
        from numpy import sqrt, where

        def half_gap(loc):
            # Distance from the extreme point to every point; the smallest
            # non-zero one is taken as the grid size at that location.
            r = sqrt((primary - primary[loc])**2.
                     + (secondary - secondary[loc])**2.)
            # Zero distances (the point itself) are masked with a large
            # placeholder so they cannot win the minimum.
            return (where(r > 0, r, 10000)).min()/2.

        lo, hi = primary.argmin(), primary.argmax()
        return primary[lo] - half_gap(lo), primary[hi] + half_gap(hi)

    def _make_leaf(self, i, dx, nside):
        '''
        Turn branch ``i`` into a leaf holding an ``nside`` x ``nside`` block
        of points with spacing ``dx`` and update the tree-wide bookkeeping.
        '''
        from numpy import linspace, meshgrid

        leaf = self[i]
        leaf.isLeaf = True
        # Leafs must "know" the indices of the x,y points located inside of
        # them as a grid and also know the bounding coords of the grid cells.
        leaf.locs = leaf.locs.reshape((nside, nside))
        leaf.dx = dx
        if dx > self.dx_max:
            self.dx_max = dx
        if dx < self.dx_min:
            self.dx_min = dx
        leaf.cells = meshgrid(
            linspace(leaf.lim[0], leaf.lim[1], nside+1),
            linspace(leaf.lim[2], leaf.lim[3], nside+1))
        self.nleafs += 1

    def _spawn_kids(self, grid, i=1):
        '''
        Internal recursive method for populating tree.

        Branch ``i`` either becomes a leaf or is split into four children
        which are then processed the same way.
        '''
        from numpy import sqrt, log2, mod

        branch = self[i]
        lim = branch.lim

        # Start by limiting locations to within block limits:
        xall = grid[0, :][self.locs]
        yall = grid[1, :][self.locs]
        branch.locs = self.locs[(xall > lim[0]) & (xall < lim[1]) &
                                (yall > lim[2]) & (yall < lim[3])]
        branch.npts = branch.locs.size

        # Bats blocks are nPoints by nPoints,
        # blocks must have no less than nPts points.
        # Stop branching as soon as possible (e.g. combine blocks of like dx).
        # Here, npts=self.blocksize**2 * 2^a
        # 2^a is the number of complete blocks of size blocksize**2 within the
        # current group of points.  If 'a' is non-integer, we include blocks
        # of different grid spacing.
        a = log2(branch.npts/self.blocksize**2)

        # Grab some grid information to be used in evaluating current block:
        xnow = grid[0, :][branch.locs]
        xmax, xmin = xnow.max(), xnow.min()

        if int(a) == a:
            # integer 'a' implies correct number of points to be a "leaf", but
            # more investigation required.

            # Approximate dx assuming a proper block.
            dx = (xmax-xmin) / (sqrt(branch.npts)-1)

            # Count points along x=xmax and x=xmin.  These are equal in Leafs.
            nxmin = xnow[xnow == xmin].size
            nxmax = xnow[xnow == xmax].size

            # Is block square?  Is it integer multiple of other blocks?
            blksize = int(sqrt(branch.npts))
            issquare = (a == 0) or (nxmax == nxmin == blksize)

            # Check for uniform grid spacing (only if a "square" block):
            isuniform = False
            if issquare:
                temploc = branch.locs.reshape((blksize, blksize))
                xsquare = grid[0, :][temploc]
                ysquare = grid[1, :][temploc]
                dxblk = abs(xsquare[1:, :]) - abs(xsquare[:-1, :])
                dyblk = abs(ysquare[:, 1:]) - abs(ysquare[:, :-1])
                isuniform = dxblk.min() == dxblk.max() \
                    == dyblk.min() == dyblk.max()

            # An NxN block can be considered a "leaf" or stopping point
            # if above criteria are met.
            if issquare and isuniform:
                self._make_leaf(i, dx, blksize)
                return
        elif branch.npts < self.blocksize**2:
            # NOTE(review): this branch is paired with the integer-'a' test
            # above; the point counts it accepts (5, 20, 45, ...) never give
            # an integer 'a', so it is only reachable from here.
            #
            # If we do not reach a true leaf but have few points,
            # we likely have hit an interface surface.  This is an
            # interface region: the space between two blocks of
            # different resolution.  Points from **both** resolutions
            # are included in the output file.  Create a leaf that uses the
            # smaller of the two and discards the rest.
            # Interfaces are considered leafs.

            # The number of points along each dimension should follow
            # this relationship at interface blocks:
            a = 2*sqrt(branch.npts/5)
            if int(a) != a:
                raise ValueError(
                    "Failure to handle interface " +
                    "surface!  Please report to Spacepy devs.")
            a = int(a)

            # Calculate dx based on this value:
            dx = (xmax-xmin) / (a-1)

            # Refine locs so that only multiples of dx are included:
            subloc = mod(xnow-grid[0, :][branch.locs[0]], dx) == 0
            branch.locs = branch.locs[subloc]

            # The cell edges are sized from the interface side length 'a';
            # the block size of the regular-leaf case is not defined here.
            self._make_leaf(i, dx, a)
            return

        # If above criteria are not met, this block is
        # not a constant-resolution zone.
        # Subdivide section into four new ones (8 if oct tree).
        # Children are ordered lower-left, lower-right, upper-right,
        # upper-left; half of the x extent is used as the size along both
        # axes.
        dx = (lim[1] - lim[0])/2.0
        x = [lim[0], lim[0]+dx, lim[0]+dx, lim[0]]
        y = [lim[2], lim[2], lim[2]+dx, lim[2]+dx]
        for j, k in enumerate(range(self.ld(i), self.rd(i)+1)):
            self[k] = [x[j], x[j]+dx, y[j], y[j]+dx]
            self._spawn_kids(grid, k)

    def find_leaf(self, x, y, i=1):
        '''
        Recursively search for and return the index of the leaf that contains
        the input point x, y.  The search starts at branch ``i``; ``False``
        is returned if no leaf below that branch contains the point.
        '''
        l = self[i].lim
        # If the point is not in this block, there is nothing to find here.
        if not ((l[0] <= x <= l[1]) and (l[2] <= y <= l[3])):
            return False
        # ...and it's a leaf, return this block.
        if self[i].isLeaf:
            return i
        # ...and it's a branch, dig deeper.
        for j in range(self.ld(i), self.rd(i)+1):
            answer = self.find_leaf(x, y, i=j)
            if answer:
                return answer
        # Inside the branch limits but not inside any child.
        return False

    def __getitem__(self, key):
        return self.tree[key]

    def __setitem__(self, key, value):
        # Values are limit lists; they are always stored as Branch objects.
        self.tree[key] = Branch(value)

    def __contains__(self, key):
        return key in self.tree

    def __iter__(self):
        return iter(self.tree)

    def keys(self):
        '''Return a list of the keys of all branches in the tree.'''
        return list(self.tree.keys())

    def mom(self, k):
        '''Return the key of the parent of branch ``k``.'''
        # Floor division: keys are integers and all children of one parent
        # must map back to the same key.
        return (k+2**self.d-2)//2**self.d

    def leftdaughter(self, k):
        '''Return the key of the first child of branch ``k``.'''
        return k*2**self.d - 2**self.d+2

    def rightdaughter(self, k):
        '''Return the key of the last child of branch ``k``.'''
        return k*2**self.d + 1

    # convenience:
    rd = rightdaughter
    ld = leftdaughter

    def plot_res(self, ax, do_label=True, do_fill=True, tag_leafs=False,
                 zlim=False, cmap='jet_r'):
        '''
        Draw every leaf on axes ``ax`` as a patch colored by its grid
        spacing ``dx`` (logarithmic color scale spanning the smallest to the
        largest leaf spacing, using color map ``cmap``).

        If ``do_label`` is set, a list of the distinct resolutions is
        annotated to the right of the axes.  ``do_fill`` is passed to each
        leaf patch; ``tag_leafs`` writes each leaf's key at its center.
        ``zlim`` is accepted but is not used by this method.
        '''
        from matplotlib.pyplot import get_cmap
        from matplotlib import patheffects
        from matplotlib.colors import LogNorm
        from matplotlib.cm import ScalarMappable

        # Create a color map using max/min resolution.
        vmin, vmax = self.dx_min, self.dx_max
        if vmin == vmax:
            # Widen a degenerate range so the norm has a non-zero span.
            vmin *= 0.9
            vmax *= 1.1
        cNorm = LogNorm(vmin=vmin, vmax=vmax, clip=True)
        cMap = ScalarMappable(cmap=get_cmap(cmap), norm=cNorm)

        dx_vals = set()
        for key in self.tree:
            if self[key].isLeaf:
                color = cMap.to_rgba(self[key].dx)
                dx_vals.add(self[key].dx)
                # key*tag_leafs is the key when tagging and 0 (no label)
                # otherwise.
                self[key].plot_res(ax, fc=color, do_fill=do_fill,
                                   label=key*tag_leafs)

        if do_label:
            ax.annotate('Resolution:', [1.02, 0.99],
                        xycoords='axes fraction', color='k', size='medium')
            for i, key in enumerate(sorted(dx_vals)):
                if key < 1:
                    label = '1/%i' % (key**-1)
                else:
                    label = '%i' % key
                ax.annotate(f'{label} $R_{{E}}$', [1.02, 0.87-i*0.1],
                            xycoords='axes fraction',
                            color=cMap.to_rgba(key), size='x-large',
                            path_effects=[patheffects.withStroke(
                                linewidth=1, foreground='k')])
class Branch(object):
    '''
    Base class for branches/leafs along a QO tree.

    A new branch is never a leaf (``isLeaf`` is False) and stores its
    bounding limits in ``lim``.
    '''

    def __init__(self, lim):
        '''
        lim should be a 4 element list of the dimensional boundaries of
        the branch.
        '''
        self.isLeaf = False
        self.lim = lim

    def plotbox(self, ax, lc='k', **kwargs):
        '''
        Plot a box encompassing the branch lim onto axis 'ax'.

        ``lc`` is the line color; remaining keyword arguments are handed to
        the line collection that draws the box.
        '''
        from matplotlib.collections import LineCollection
        from numpy import array

        bounds = self.lim
        left, right = bounds[0], bounds[1]
        bottom, top = bounds[2], bounds[3]

        # Four edges, in the order: bottom, top, left, right.
        edges = (
            array([[left, bottom], [right, bottom]]),
            array([[left, top], [right, top]]),
            array([[left, bottom], [left, top]]),
            array([[right, bottom], [right, top]]))
        opacity = 1
        box = LineCollection(edges, colors=lc, alpha=opacity, **kwargs)
        ax.add_collection(box)

    def plot_res(self, ax, fc='gray', do_fill=True, label=False):
        '''
        Add a filled rectangle covering this leaf's limits to axis 'ax'.

        Does nothing for branches that are not leafs.  ``fc`` is the face
        color and ``do_fill`` toggles the fill; a truthy ``label`` is
        written as text at the center of the rectangle.
        '''
        if not self.isLeaf:
            return

        from matplotlib.patches import Polygon
        from numpy import array

        bounds = self.lim
        corners = array([
            [bounds[0], bounds[2]],
            [bounds[1], bounds[2]],
            [bounds[1], bounds[3]],
            [bounds[0], bounds[3]]])
        opacity = 1
        patch = Polygon(corners, closed=True, ec=None, fc=fc, fill=do_fill,
                        lw=0.0, alpha=opacity)
        if label:
            xmid = bounds[0]+(bounds[1]-bounds[0])/2.0
            ymid = bounds[2]+(bounds[3]-bounds[2])/2.0
            ax.text(xmid, ymid, label)
        ax.add_patch(patch)