#!/net/python/bin/python
"""Tabulated two-dimensional function z(x, y) with resampling, arithmetic
and surface plotting."""

# NOTE: the original import order is kept on purpose, because the star
# imports may rebind names that later imports are expected to override.
# Only the standard-library ``copy`` import is new.
from Header import *
import os
import copy as _copy
import numpy as num
from mathfuncs import splev, splrep
import pylab as p
import matplotlib.axes3d as p3
import types
from Function1D import *


def _is_scalar(value):
    """Return True when value is a plain number (int, long or float).

    isinstance is used rather than an exact type comparison so that
    subclasses of the builtin numeric types are accepted as well.
    """
    return isinstance(value,
                      (types.IntType, types.FloatType, types.LongType))


class Function2D:
    """
    A class to hold a 2D function sampled on a rectangular grid.

    Attributes:
      x -- 1D array of x coordinates
      y -- 1D array of y coordinates
      z -- 2D array indexed as z[xi, yi]: for each x there is one array of
           z values, one per y
    """

    def __init__(self, filename=NUGENT_SPECTRUM):
        # filename is the path to a file that has three columns of data:
        #   column 1 is x, column 2 is y, column 3 is z
        # NUGENT_SPECTRUM is not defined in this file; presumably it comes
        # from the Header star import.
        self.getfromfile(filename)

    def copy(self):
        """Return an independent copy whose x, y and z arrays are duplicated.

        The instance is cloned with copy.copy(), which does not call
        __init__, so the default data file is not read again.
        """
        duplicate = _copy.copy(self)
        duplicate.x = num.array(self.x)
        duplicate.y = num.array(self.y)
        duplicate.z = num.array(self.z)
        return duplicate

    def getfromfile(self, filename):
        """
        imports a function from filename

        filename is assumed to hold three whitespace-separated columns:
        x in column 1, y in column 2, z in column 3.  Rows must be grouped
        by x.  The y grid is taken from the first x group only, so every
        group is expected to list the same y values in the same order.
        Blank lines and lines starting with # are skipped.

        If filename does not exist a message is printed and x, y and z are
        left as empty arrays.
        """
        self.x = num.array([])
        self.y = num.array([])
        self.z = num.array([])

        if not os.path.exists(filename):
            print('Function2D: ' + filename + ' is not a valid file')
            return

        xvals = []
        yvals = []
        zrows = []
        # try/finally instead of a with-statement: nothing else in this
        # file requires an interpreter that has the with-statement.
        infile = open(filename, 'r')
        try:
            for line in infile:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                fields = [float(token) for token in line.split()]
                if not xvals or fields[0] != xvals[-1]:
                    # a new x has been found
                    xvals.append(fields[0])
                    zrows.append([])
                if len(xvals) == 1:
                    # y is only filled out while reading the first x group
                    yvals.append(fields[1])
                zrows[-1].append(fields[2])
        finally:
            infile.close()

        self.x = num.array(xvals)
        self.y = num.array(yvals)
        self.z = num.array(zrows)

    def inrange_x(self, val):
        """Return whether val lies within [x[0], x[-1]] (False if x is empty)."""
        if len(self.x) == 0:
            return False
        return (val <= self.x[-1] and val >= self.x[0])

    def inrange_y(self, val):
        """Return whether val lies within [y[0], y[-1]] (False if y is empty)."""
        if len(self.y) == 0:
            return False
        return (val <= self.y[-1] and val >= self.y[0])

    def contract(self, xmin, xmax, ymin, ymax):
        """Return a new function restricted to xmin <= x <= xmax and
        ymin <= y <= ymax.

        The returned arrays are copies, so modifying the result never
        changes this object.  An IndexError is raised when no grid point
        satisfies one of the bounds.
        """
        xi_min = num.where(self.x >= xmin)[0][0]
        xi_max = num.where(self.x <= xmax)[0][-1]
        yi_min = num.where(self.y >= ymin)[0][0]
        yi_max = num.where(self.y <= ymax)[0][-1]

        trimmed = _copy.copy(self)
        # .copy() matters: plain slices are views, and an in-place update of
        # the result (as done by __add__) would otherwise write into self.
        trimmed.x = self.x[xi_min:xi_max + 1].copy()
        trimmed.y = self.y[yi_min:yi_max + 1].copy()
        trimmed.z = self.z[xi_min:xi_max + 1, yi_min:yi_max + 1].copy()
        return trimmed

    def getslice_x(self, xval):
        """
        returns a Function1D object corresponding to the slice at x=xval

        The slice is read straight from the grid when xval is a grid point,
        otherwise it is obtained from resample().
        """
        F = Function1D()
        F.x = num.array(self.y)
        if xval in self.x:
            xi = self.x.searchsorted(xval)
            F.y = num.array(self.z[xi, :])
        else:
            self_interp = self.resample([xval], self.y)
            F.y = num.array(self_interp.z[0, :])
        return F

    def getslice_y(self, yval):
        """
        returns a Function1D object corresponding to the slice at y=yval

        The slice is read straight from the grid when yval is a grid point,
        otherwise it is obtained from resample().
        """
        F = Function1D()
        F.x = num.array(self.x)
        if yval in self.y:
            yi = self.y.searchsorted(yval)
            F.y = num.array(self.z[:, yi])
        else:
            self_interp = self.resample(self.x, [yval])
            F.y = num.array(self_interp.z[:, 0])
        return F

    def value(self, xval, yval):
        """Return z at (xval, yval).

        Exact grid points are looked up directly; any other point is
        evaluated by resampling onto that single point.  Points outside the
        grid are not supported, because resample() drops them.
        """
        if (xval in self.x) and (yval in self.y):
            xi = self.x.searchsorted(xval)
            yi = self.y.searchsorted(yval)
            return self.z[xi, yi]
        return self.resample([xval], [yval]).z[0, 0]

    def resample(self, xrange, yrange):
        """Return a new function evaluated on the grid xrange by yrange.

        Requested coordinates outside this function's own range are dropped
        without notice.  Interpolation is separable: splrep/splev are applied
        first along y for every existing x, then along x for every new y.
        splrep and splev come from the project module mathfuncs; presumably
        they are spline fit/evaluate routines -- confirm there.
        """
        resampled = _copy.copy(self)
        resampled.x = num.array([val for val in xrange
                                 if self.inrange_x(val)])
        resampled.y = num.array([val for val in yrange
                                 if self.inrange_y(val)])
        resampled.z = num.zeros([len(resampled.x), len(resampled.y)])

        # first get all the y values lined up
        along_y = num.zeros([len(self.x), len(resampled.y)])
        for i in range(len(self.x)):
            sp = splrep(self.y, self.z[i, :])
            along_y[i, :] = splev(resampled.y, sp)

        # now get all the x values lined up
        for j in range(len(resampled.y)):
            sp = splrep(self.x, along_y[:, j])
            resampled.z[:, j] = splev(resampled.x, sp)

        return resampled

    def _scale_z(self, other):
        """Multiply self.z in place by other; return True on success.

        other is either a plain number or a Python function f(x, y) that is
        evaluated at every grid point.  On an unsupported multiplier a
        message is printed and False is returned.
        """
        if _is_scalar(other):
            self.z *= other
            return True

        if isinstance(other, types.FunctionType):
            # probe call to check that the function accepts two arguments
            try:
                other(0, 0)
            except Exception:
                print("Function2D:__mul__: multiplied function must take (x,y)")
                return False
            for i, xval in enumerate(self.x):
                for j, yval in enumerate(self.y):
                    self.z[i, j] *= other(xval, yval)
            return True

        print("Function2D:__mul__:unrecognized multiplier")
        return False

    def __mul__(self, other):
        """Return a scaled copy (self * other), or None if other is
        unsupported."""
        product = self.copy()
        if product._scale_z(other):
            return product
        return None

    def __rmul__(self, other):
        """Multiplication is commutative here: other * self == self * other."""
        return self.__mul__(other)

    def __imul__(self, other):
        """Scale z in place (self *= other).

        Returns None for an unsupported multiplier, which rebinds the
        left-hand name to None.
        """
        if self._scale_z(other):
            return self
        return None

    def __div__(self, other):
        """Return a copy with z divided by other."""
        quotient = self.copy()
        quotient.z /= other
        return quotient

    def __idiv__(self, other):
        """Divide z in place by other."""
        self.z /= other
        return self

    # Same behaviour when the / operator uses true-division semantics.
    __truediv__ = __div__
    __itruediv__ = __idiv__

    def _sum_on_overlap(self, other):
        """Return self + other on the region where both are defined.

        The result lives on this function's grid, cut down to the overlap of
        the two domains; other is resampled onto that grid before adding.
        """
        xmin = max(self.x[0], other.x[0])
        xmax = min(self.x[-1], other.x[-1])
        ymin = max(self.y[0], other.y[0])
        ymax = min(self.y[-1], other.y[-1])
        total = self.contract(xmin, xmax, ymin, ymax)
        total.z += other.resample(total.x, total.y).z
        return total

    def __add__(self, other):
        """Return self + other, where other is a number or a Function2D.

        Prints a message and returns None for any other operand.
        """
        if _is_scalar(other):
            total = self.copy()
            total.z += other
            return total
        if isinstance(other, Function2D):
            return self._sum_on_overlap(other)
        print("Function2D:__add__:unrecognized addend")
        return None

    def __iadd__(self, other):
        """Add other to this function in place (self += other).

        When other is a Function2D this object's grid shrinks to the overlap
        of the two domains.  Prints a message and returns None for any other
        unsupported operand.
        """
        if _is_scalar(other):
            self.z += other
            return self
        if isinstance(other, Function2D):
            total = self._sum_on_overlap(other)
            self.x = total.x
            self.y = total.y
            self.z = total.z
            return self
        print("Function2D:__add__:unrecognized addend")
        return None

    def __radd__(self, other):
        """Addition is commutative here: other + self == self + other."""
        return self.__add__(other)

    def plotsurface(self, xmin=-10, xmax=45, ymin=2500, ymax=10000,
                    logplot=0):
        """
        Creates a surface plot of the 2D function

        Only grid points with xmin <= x < xmax and ymin <= y < ymax are
        drawn (the upper bounds are exclusive because searchsorted returns
        the left insertion index).  If logplot is true, log10(z) is plotted.
        """
        xi_min = num.searchsorted(self.x, xmin)
        xi_max = num.searchsorted(self.x, xmax)
        yi_min = num.searchsorted(self.y, ymin)
        yi_max = num.searchsorted(self.y, ymax)

        x_sel = self.x[xi_min:xi_max]
        y_sel = self.y[yi_min:yi_max]
        z_plot = num.array(self.z[xi_min:xi_max, yi_min:yi_max])
        if logplot:
            z_plot = num.log10(z_plot)

        # expand the 1D axes so that x, y and z all have the same 2D shape
        x_plot = num.multiply.outer(x_sel, num.ones(len(y_sel)))
        y_plot = num.multiply.outer(num.ones(len(x_sel)), y_sel)

        fig = p.figure()
        ax = p3.Axes3D(fig)
        # works nice
        ax.plot_wireframe(x_plot, y_plot, z_plot)
        ax.contour3D(x_plot, y_plot, z_plot, 100, cmap=p.cm.jet)
        ax.set_xlabel('Time')
        ax.set_ylabel('Wavelength')
        ax.set_zlabel('Flux')
        # elev, az
        ax.view_init(40, -10)  # works for log flux


if __name__ == '__main__':
    F1 = Function2D()
    F1.plotsurface()
    p.show()