Skip to content
Snippets Groups Projects
Select Git revision
  • 931330eb4a8a19cb791d2b2606196e81c886ab38
  • master default protected
  • new_cbl_models
  • 0.99
4 results

graphics.py

Blame
  • graphics.py 17.91 KiB
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    import numpy as np
    from matplotlib import pyplot as p
    import matplotlib.colors as colors
    from ENDA import experiment
    
    def is_sorted(a):
        """Return whether the 1-D array *a* is in non-decreasing order.

        Empty and single-element arrays count as sorted.
        """
        return np.all(a[:-1] <= a[1:])
    
    def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
        """Build a new colormap from the [minval, maxval] slice of *cmap*.

        Adapted from
        https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib

        Parameters
        ----------
        cmap : matplotlib colormap
            Source colormap; its ``name`` is embedded in the new colormap's name.
        minval, maxval : float
            Bounds, in colormap coordinates, of the retained slice.
        n : int
            Number of colours sampled from *cmap* between the bounds.

        Returns
        -------
        matplotlib.colors.LinearSegmentedColormap
        """
        label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
        sampled = cmap(np.linspace(minval, maxval, n))
        return colors.LinearSegmentedColormap.from_list(label, sampled)
    
    def plot_CBL_assimilation(exp,figtitle,which_cycle=0,ax=None):
        """Plot prior and posterior ensembles against the observations of one cycle.

        Parameters
        ----------
        exp : experiment-like object
            Reads ``exp.observations``, ``exp.obs_coordinates``,
            ``exp.da.backgrounds``, ``exp.da.analyses`` and ``exp.nr.nz``,
            ``exp.nr.zt``, ``exp.nr.is_bwind``.
        figtitle : str
            Output PNG path; only used when *ax* is None.
        which_cycle : int
            Index of the assimilation cycle to plot.
        ax : matplotlib Axes, optional
            If given, the prior (red) and posterior (blue) theta ensembles and
            the observations are drawn on it and it is returned. If None, a
            figure with one row per variable (theta, plus u and v when
            ``exp.nr.is_bwind``) and four panels per row (nature run, prior,
            posterior, mean increment) is saved to *figtitle*; None is returned.

        Raises
        ------
        ValueError
            If ``exp.obs_coordinates`` is neither 1-D nor 2-D.
        """
        naturerun_color = 'black'
        ensmembers_color = 'lightskyblue'
        ensmean_color = 'royalblue'

        def ensplot(target,x,x1,x2,z,line_colors=(ensmean_color,ensmembers_color),label=None):
            # Ensemble mean (on top) and every member (faint, below) for
            # state-vector rows x1:x2; columns of x are ensemble members.
            nens = x.shape[1]
            target.plot(x[x1:x2].mean(axis=1),z,color=line_colors[0],zorder=5,label=label)
            for i in range(nens):
                target.plot(x[x1:x2,i],z,color=line_colors[1],alpha=0.2,zorder=1)

        # Shorthand
        o = exp.observations[which_cycle]
        b = exp.da.backgrounds[which_cycle,:,:]
        a = exp.da.analyses[which_cycle,:,:]
        if exp.obs_coordinates.ndim == 1:
            # Same observation coordinates for every cycle
            ocoords = exp.obs_coordinates
        elif exp.obs_coordinates.ndim == 2:
            # One row of observation coordinates per cycle
            ocoords = exp.obs_coordinates[which_cycle]
        else:
            raise ValueError('obs_coordinates must be 1-D or 2-D, got %u-D'
                             % exp.obs_coordinates.ndim)
        nz = exp.nr.nz
        zt = exp.nr.zt

        # State vector is stacked as theta[0:nz] (and, with winds,
        # u[nz:2nz], v[2nz:3nz]); one figure row per variable.
        nrows = 3 if exp.nr.is_bwind else 1
        zmax = 2000

        if ax is not None:
            ensplot(ax,b,0,nz,zt,line_colors=('red','salmon'),label='prior')
            ensplot(ax,a,0,nz,zt,line_colors=('royalblue','lightskyblue'),label='posterior')
            ax.scatter(o,ocoords,marker="o",s=20,color=naturerun_color,zorder=10)
            return ax

        fig = p.figure(151)
        fig.set_size_inches(9,3*nrows)

        def variable_row(row,i0,i1,prefix,share,show_obs):
            # Draw the nature-run, prior, posterior and mean-increment panels of
            # figure row `row` for state rows i0:i1. The last three panels share
            # y with `share` (the row's own first panel if share is None).
            # Returns the first panel and the list of the other three.
            first = 4*row + 1
            nature = fig.add_subplot(nrows,4,first)
            if show_obs:
                nature.scatter(o,ocoords,marker="o",s=20,color=naturerun_color,zorder=10)
            nature.set_title(prefix + ', nature run')
            nature.set_ylabel(r'$z$ (m)')
            nature.set_ylim([0,zmax])
            if share is None:
                share = nature
            others = []
            for offset, ens, stage in ((1, b, 'prior'), (2, a, 'posterior')):
                panel = fig.add_subplot(nrows,4,first+offset,sharey=share)
                ensplot(panel,ens,i0,i1,zt)
                if show_obs:
                    panel.scatter(o,ocoords,marker="o",s=20,color=naturerun_color,zorder=10)
                panel.set_title(prefix + ', ' + stage)
                others.append(panel)
            increment = fig.add_subplot(nrows,4,first+3,sharey=share)
            increment.plot(a[i0:i1].mean(axis=1)-b[i0:i1].mean(axis=1),zt,color=ensmean_color)
            increment.set_title(prefix + ', mean increment',zorder=10)
            others.append(increment)
            return nature, others

        # Potential temperature (observations are overlaid on this row only)
        ax1, hidden = variable_row(0,0,nz,r'$\theta$ (K)',None,True)
        for panel in hidden:
            p.setp(panel.get_yticklabels(), visible=False)

        if exp.nr.is_bwind:
            # u and v wind components; their right-hand panels share y with ax1
            _, hidden_u = variable_row(1,nz,2*nz,r'$u$ (m/s)',ax1,False)
            _, hidden_v = variable_row(2,2*nz,3*nz,r'$v$ (m/s)',ax1,False)
            for panel in hidden_u + hidden_v:
                p.setp(panel.get_yticklabels(), visible=False)

        fig.savefig(figtitle,format='png',dpi=300)
        p.close(fig)
    
    def plot_CBL_PE(exp,figtitle,parameter_id=0,plot_spread=False,ax=None):
        """Plot the evolution of one estimated parameter over the DA cycles.

        Nothing is done unless ``exp.nr.do_parameter_estimation`` is true.
        Otherwise a histogram figure of the first- and last-cycle background
        parameter ensembles (physical and transformed space) is always saved
        to ``'p_histogram_%s.png' % exp.label``. The parameter time series
        (median, mean and 10-90 percentile band) is then drawn as well.

        Parameters
        ----------
        exp : experiment-like object
            Reads ``exp.nr`` (``do_parameter_estimation``, ``pfac``,
            ``parameter_transform``, ``parameter_number``, ``parameter_true``),
            ``exp.da`` (``backgrounds``, ``initial_perturbed_parameters``,
            ``time``) and ``exp.label``.
        figtitle : str
            Output PNG path of the time-series figure; used only if *ax* is None.
        parameter_id : int
            Which of the ``exp.nr.parameter_number`` estimated parameters to plot.
        plot_spread : bool
            If True and *ax* is None, add a second panel with the ensemble
            standard deviation in physical and transformed space.
        ax : matplotlib Axes, optional
            Draw the time series on this Axes and return it instead of saving
            a figure.

        Returns
        -------
        The Axes used when *ax* is given, otherwise None.
        """
        if not exp.nr.do_parameter_estimation:
            return None

        spread_color = 'lightskyblue'
        mean_color = 'royalblue'
        mean_color_2 = 'orange'

        # Reference value for the dashed horizontal line. It is only known for
        # the first parameter (exp.nr.pfac); for other ids no line is drawn.
        true_value = exp.nr.pfac if parameter_id == 0 else None

        pt = exp.nr.parameter_transform
        parameter_number = exp.nr.parameter_number
        # Parameters are the last parameter_number rows of the state vector;
        # par_tran is indexed [cycle, member], in transformed space.
        par_tran = exp.da.backgrounds[:,-parameter_number:,:][:,parameter_id,:]
        # kind='inv' maps transformed -> physical, kind='dir' the other way
        par_phys = pt[parameter_id](par_tran,kind='inv')
        initpar_phys = pt[parameter_id](exp.da.initial_perturbed_parameters,kind='inv')
        t = exp.da.time/3600.   # hours, assuming exp.da.time is in seconds

        plot_parameter_histogram = True
        if plot_parameter_histogram:
            # First vs. last cycle ensembles with their 10-90 percentile bands
            p10beg_tran = np.percentile(par_tran[0,:],  10)
            p90beg_tran = np.percentile(par_tran[0,:],  90)
            p10end_tran = np.percentile(par_tran[-1,:], 10)
            p90end_tran = np.percentile(par_tran[-1,:], 90)
            p10beg_phys = np.percentile(par_phys[0,:],  10)
            p90beg_phys = np.percentile(par_phys[0,:],  90)
            p10end_phys = np.percentile(par_phys[-1,:], 10)
            p90end_phys = np.percentile(par_phys[-1,:], 90)
            #
            fig = p.figure(151)
            fig.set_size_inches(6,3)
            #
            ax1 = fig.add_subplot(1,2,1)
            ax1.hist(par_phys[0,:], bins=np.linspace(0,5,51), density=True, histtype='step', color='blue', label='initial')
            ax1.axvspan(p10beg_phys,p90beg_phys,color='blue',lw=0,alpha=0.1)
            ax1.hist(par_phys[-1,:], bins=np.linspace(0,5,51), density=True, histtype='step', color='red', label='final')
            ax1.axvspan(p10end_phys,p90end_phys,color='red',lw=0,alpha=0.1)
            ax1.axvline(x=exp.nr.parameter_true,color='black',dashes=[3,1])
            ax1.set_xlabel(r'$p$')
            ax1.set_ylabel('probability density')
            #
            ax2 = fig.add_subplot(1,2,2)
            ax2.hist(par_tran[0,:], bins=np.linspace(-10,10,51), density=True, histtype='step', color='blue', label='initial')
            ax2.axvspan(p10beg_tran,p90beg_tran,color='blue',lw=0,alpha=0.1)
            ax2.hist(par_tran[-1,:], bins=np.linspace(-10,10,51), density=True, histtype='step', color='red', label='final')
            ax2.axvspan(p10end_tran,p90end_tran,color='red',lw=0,alpha=0.1)
            ax2.axvline(x=pt[parameter_id](exp.nr.parameter_true,kind='dir'),color='black',dashes=[3,1])
            ax2.set_xlabel(r"$p\prime$")
            ax2.set_ylabel('probability density')
            #
            fig.tight_layout()
            fig.savefig('p_histogram_%s.png'%exp.label,format='png',dpi=300)
            p.close(fig)

        if ax is None:
            nrows = 2 if plot_spread else 1
            fig = p.figure(151)
            fig.set_size_inches(4,3*nrows)
            ax1 = fig.add_subplot(nrows,1,1)
        else:
            ax1 = ax

        # Initial perturbed parameters, drawn as constant over [0, t[0]]
        ax1.step([0,t[0]],np.median(initpar_phys,axis=1)*np.array([1,1]),color=mean_color)
        ax1.step([0,t[0]],np.mean(initpar_phys,axis=1)*np.array([1,1]),color=mean_color_2)
        bmin = np.percentile(initpar_phys, 10)
        bmax = np.percentile(initpar_phys, 90)
        ax1.fill_between([0,t[0]],
                         y1=[bmin,bmin],
                         y2=[bmax,bmax],
                         color=spread_color,
                         edgecolor=None,
                         alpha=0.3)
        # Later cycles; the band over [t[i], t[i+1]] uses the spread at t[i+1].
        # Percentiles are computed once for all cycles, not once per interval.
        ax1.step(t,np.median(par_phys,axis=1),color=mean_color)
        ax1.step(t,np.mean(par_phys,axis=1),color=mean_color_2)
        p10 = np.percentile(par_phys, 10, axis=1)
        p90 = np.percentile(par_phys, 90, axis=1)
        for i in range(t.size-1):
            ax1.fill_between(t[i:i+2],
                             y1=[p10[i+1],p10[i+1]],
                             y2=[p90[i+1],p90[i+1]],
                             color=spread_color,
                             edgecolor=None,
                             alpha=0.3)
        if true_value is not None:
            ax1.axhline(y=true_value,linestyle='--',color='black')
        ax1.set_xlim([0,t.max()])

        if ax is None and plot_spread:
            ax2 = fig.add_subplot(nrows,1,2)
            ax2.plot(t,par_phys.std(axis=1),color=mean_color,label='physical')
            ax2.plot(t,par_tran.std(axis=1),color=mean_color_2,label='transformed', linestyle='--')
            ax2.legend()
            ax2.set_title('Parameter spread')

        if ax is None:
            fig.tight_layout()
            fig.savefig(figtitle,format='png',dpi=300)
            p.close(fig)
            return None
        return ax1
    
    def plot_CBL_identifiability(cbl,sigma_o,ax=None):
        """Plot parameter-identifiability diagnostics in time-height space.

        Each panel is a pcolormesh of one diagnostic (cycle time in hours vs.
        observation coordinate) overlaid with dashed contours of the
        ensemble-mean model equivalents ``cbl.da.model_equivalents``.

        Parameters
        ----------
        cbl : experiment-like object
            Reads ``cbl.da`` (``observations``, ``ncycles``, ``time``,
            ``model_equivalents``, ``obs_coordinates``, ``backgrounds``,
            ``cov_yy``, ``cov_pp``, ``cov_py``, ``innov``, ``increments``) and
            ``cbl.nr`` (``parameter_number``, ``parameter_transform``,
            ``theta_0``, ``gamma``). Its arrays are not modified.
        sigma_o : float
            Observation-error standard deviation.
        ax : None, matplotlib Axes, or list/tuple of 2 or 3 Axes
            Selects what is drawn:

            * one Axes: correlation ``A = cov_py/sqrt(cov_yy*cov_pp)``;
            * two Axes: ``D = sqrt(cov_pp/cov_yy)`` and
              ``E = innov*cov_yy/(cov_yy+sigma_o**2)``;
            * three Axes: ``A``,
              ``B = innov*sqrt(cov_pp*cov_yy)/(cov_yy+sigma_o**2)`` and the
              parameter increment ``C``;
            * None: nothing is drawn.

        Returns
        -------
        None for ``ax=None``; otherwise ``(ax, c0)``, ``(ax, c0, c1)`` or
        ``(ax, c0, c1, c2)`` with one pcolormesh artist per panel.

        Raises
        ------
        AttributeError
            If ``cbl.da`` has no ``cov_py`` diagnostic.
        ValueError
            If ``cbl.da.obs_coordinates`` is neither 1-D nor 2-D.
        TypeError
            If *ax* is not one of the supported forms.
        """
        cmap = p.get_cmap('RdBu_r')
        new_cmap = truncate_colormap(cmap, 0.5, 1.0)
        zmax = 2000
        ncont = 13

        da = cbl.da
        if not hasattr(da,'cov_py'):
            raise AttributeError('cbl.da has no attribute cov_py; '
                                 'identifiability diagnostics are unavailable')

        # Read relevant dimensions
        nobs = da.observations.shape[1]
        npar = cbl.nr.parameter_number
        ntimes = da.ncycles

        # Cycle times expanded to one column per observation (ntim*nobs)
        times = da.time[:,None]+np.zeros((ntimes,nobs))

        # Private copies: the re-sorting below must not reorder the experiment's
        # own arrays (cbl.da.observations is never re-sorted, so sorting in
        # place would leave the stored data mutually inconsistent).
        theta = np.array(da.model_equivalents)   # ntim*nobs*nens
        innov = np.array(da.innov)               # ntim*nobs
        cov_yy = np.array(da.cov_yy)             # ntim*nobs
        cov_pp = np.array(da.cov_pp)             # ntim*nobs(*npar)
        cov_py = np.array(da.cov_py)             # ntim*nobs(*npar)
        increments = np.array(da.increments)     # ntim*nobs(*npar)

        # Observation coordinates as ntim*nobs
        if da.obs_coordinates.ndim == 1:
            zobs = da.obs_coordinates[None,:]+np.zeros(times.shape)
        elif da.obs_coordinates.ndim == 2:
            zobs = np.array(da.obs_coordinates)
        else:
            raise ValueError('obs_coordinates must be 1-D or 2-D, got %u-D'
                             % da.obs_coordinates.ndim)

        # At each time, if the coordinates are not sorted, sort them and
        # reorder every per-observation field the same way (presumably so that
        # pcolormesh/contour see monotonic coordinates).
        for k in range(ntimes):
            if not is_sorted(zobs[k,:]):
                order = np.argsort(zobs[k,:])
                for field in (zobs, theta, innov, cov_yy, cov_pp, cov_py, increments):
                    field[k,:] = field[k,order]

        sigma2_yb = cov_yy                    # ntim*nobs
        sigma2_p = np.squeeze(cov_pp)         # ntim*nobs(*npar)
        covariance = np.squeeze(cov_py)       # ntim*nobs(*npar)
        correlation = covariance/(sigma2_yb*sigma2_p)**0.5
        deltap = np.squeeze(increments)       # ntim*nobs(*npar)

        # Shorthand for plots (transposed to nobs*ntim)
        A = correlation.T
        B = ( innov*(sigma2_p*sigma2_yb)**0.5/(sigma2_yb+sigma_o**2) ).T
        C = deltap.T
        D = ( (sigma2_p/sigma2_yb)**0.5 ).T
        E = ( innov*sigma2_yb/(sigma2_yb+sigma_o**2) ).T

        # Contour levels, transposed coordinates and ensemble-mean field
        contours = np.linspace(cbl.nr.theta_0,cbl.nr.theta_0+cbl.nr.gamma*zmax,ncont)
        hours = times.T/3600.
        zobs = zobs.T
        theta_mean = theta.mean(axis=2).T

        def draw(target,field,**mesh_kwargs):
            # One diagnostic panel: colour mesh plus dashed mean-theta contours.
            mesh = target.pcolormesh(hours,zobs,field,**mesh_kwargs)
            target.contour(hours,zobs,theta_mean,
                           contours,
                           colors='black',
                           linestyles='--',
                           linewidths=0.75)
            target.set_ylim([0,zmax])
            return mesh

        make_parameter_state_scatterplots = False
        if make_parameter_state_scatterplots:
            # NOTE(review): theta[:,j,m] has ntim elements while pfac is
            # ntim*nens, so these scatter calls look size-mismatched; confirm
            # the intended indexing before enabling.
            pfac = da.backgrounds[:,-npar,:] # ntim*nens
            pfac = cbl.nr.parameter_transform[0](pfac,kind='inv')
            fig, [[ax1, ax2],[ax3, ax4]] = p.subplots(2,2,constrained_layout=True)
            fig.set_size_inches(4,4)
            ax3.scatter(theta[:,0,0] , pfac) #ll
            ax4.scatter(theta[:,0,-1], pfac) #lr
            ax1.scatter(theta[:,-1,0], pfac) #ul
            ax2.scatter(theta[:,-1,-1],pfac) #ur
            fig.savefig('p_scatter_%s.png'%cbl.label,format='png',dpi=300)
            p.close(fig)

        if ax is None:
            return None

        # isinstance (not type() is) also accepts Axes subclasses such as the
        # AxesSubplot returned by add_subplot in older matplotlib versions.
        if isinstance(ax, p.Axes):
            c0 = draw(ax,A,vmin=-1,vmax=1,cmap='RdBu_r')
            return ax,c0

        if isinstance(ax, (list, tuple)) and len(ax) == 2:
            c0 = draw(ax[0],D,cmap=new_cmap)
            c1 = draw(ax[1],E,norm=colors.CenteredNorm(),cmap='RdBu_r')
            return ax,c0,c1

        if isinstance(ax, (list, tuple)) and len(ax) == 3:
            c0 = draw(ax[0],A,vmin=-1,vmax=1,cmap='RdBu_r')
            c1 = draw(ax[1],B,norm=colors.CenteredNorm(),cmap='RdBu_r')
            c2 = draw(ax[2],C,norm=colors.CenteredNorm(),cmap='RdBu_r')
            return ax,c0,c1,c2

        raise TypeError('ax must be None, a matplotlib Axes, or a list/tuple '
                        'of 2 or 3 Axes; got %r' % (ax,))
    
    def plot_p(p_factors,theta_profiles,zt,figtitle,ax=None):
        """Plot the eddy-diffusivity shape and theta profiles for several p.

        Parameters
        ----------
        p_factors : sequence of float
            Values of the profile exponent p.
        theta_profiles : sequence of array-like
            ``theta_profiles[i]`` is plotted against *zt* and labelled with
            ``p_factors[i]``.
        zt : array-like
            Heights of the theta profiles (the y axis is labelled in m).
        figtitle : str
            Output PNG path; used only when *ax* is None.
        ax : matplotlib Axes, optional
            If given, only the theta profiles are drawn on it and it is
            returned. Otherwise a two-panel figure is saved to *figtitle*:
            ``K/(kappa w_s z_i) = (z/z_i)*(1-z/z_i)**p`` on the left and the
            theta profiles on the right; None is returned.
        """
        if ax is not None:
            for i, pfac in enumerate(p_factors):
                ax.plot(theta_profiles[i],zt,label='$p=%4.1f$'%pfac)
            return ax

        zoverh = np.linspace(0,1,101)

        fig = p.figure(151)
        fig.set_size_inches(6,3)
        #
        ax1 = fig.add_subplot(1,2,1)
        for pfac in p_factors:
            Koverkws = zoverh*(1-zoverh)**pfac
            ax1.plot(Koverkws,zoverh,label='$p=%4.1f$'%pfac)
        # Raw string: '\k' is an invalid escape sequence in a normal literal
        ax1.set_xlabel(r'$K/(\kappa w_s z_i)$')
        ax1.set_ylabel('$z/z_i$')
        #
        ax2 = fig.add_subplot(1,2,2)
        for i, pfac in enumerate(p_factors):
            ax2.plot(theta_profiles[i],zt,label='$p=%4.1f$'%pfac)
        ax2.set_xlabel(r'$\theta$ (K)')
        ax2.set_ylabel(r'$z$ (m)')
        ax2.set_ylim([0,1500])
        ax2.set_xlim([291,297])
        ax2.legend(loc=4)
        #
        fig.tight_layout()
        fig.savefig(figtitle,format='png',dpi=300)
        p.close(fig)
    
    def plot_spread(cbl,plot='spread',ax=None):
        """Time-height plot of the ensemble spread or mean of theta.

        The member histories ``cbl.history['0000']``, ``['0001']``, ... are
        stacked into an nens*nz*ntimes array. Its standard deviation
        (``plot='spread'``) or mean (``plot='mean'``) over members is drawn
        with pcolormesh against time/3600 (hours, assuming ``time`` is in
        seconds) and ``cbl.zt``. Dashed white contours of the ensemble-mean
        theta are overlaid.

        Returns
        -------
        ``(ax, mesh)`` when *ax* is given; otherwise None and nothing is drawn.

        Raises
        ------
        ValueError
            If *plot* is neither 'spread' nor 'mean'.
        """
        if plot not in ('spread', 'mean'):
            raise ValueError("plot must be 'spread' or 'mean', got %r" % (plot,))

        # Read relevant dimensions
        times = cbl.history['0000']['time']
        zt = cbl.zt
        nz = cbl.nz
        ntimes = times.size
        nens = cbl.nens

        # Reconstruct history as nens*nz*ntimes
        theta = np.full((nens,nz,ntimes),np.nan)
        for k in range(nens):
            theta[k,:,:] = cbl.history['%04u'%k]['theta']

        # Reduce over ensemble members
        if plot == 'spread':
            bkgd = theta.std(axis=0)
        else:
            bkgd = theta.mean(axis=0)

        zmax = 2000
        ncont = 13

        if ax is not None:
            c = ax.pcolormesh(times/3600.,zt,bkgd)
            ax.contour(times/3600.,zt,theta.mean(axis=0),
                       np.linspace(cbl.theta_0,cbl.theta_0+cbl.gamma*zmax,ncont),
                       colors='white',
                       linestyles='--',
                       linewidths=0.75)
            ax.set_ylim([0,zmax])
            return ax,c
        return None
    
    def plot_diagnostics(experiments_pe,experiments_nope,labels,filename):
        """Compare error statistics of paired experiments in a 2x2 figure.

        For each index i, ``experiments_pe[i]`` is drawn as a solid line
        labelled ``labels[i]``. ``experiments_nope[i]`` is drawn as a faint
        dashed line of the same colour. Both are plotted against the
        observation coordinates of ``experiments_pe[0]``. The panels show:

        (a) ``dg.aRMSE_t``;
        (b) ``dg.bRMSE_t``;
        (c) ``dg.bRMSE_t - dg.aRMSE_t``;
        (d) ``dg.bSprd_t / dg.bRMSE_t``, restricted to coordinates <= 1500 m.

        Line colours cycle through the current matplotlib property cycle, so
        any number of experiments is supported. The figure is saved as PNG to
        *filename*.
        """
        linecolors = p.rcParams['axes.prop_cycle'].by_key()['color']

        fig, [[ax1, ax2],[ax3, ax4]] = p.subplots(2,2,constrained_layout=True)
        fig.set_size_inches(6,4)
        z = experiments_pe[0].obs_coordinates
        # Heights above 1500 m are masked out of panel (d)
        z_pbl = z*1.
        z_pbl[z>1500] = np.nan
        for i, exp_pe in enumerate(experiments_pe):
            i1 = exp_pe.dg
            i2 = experiments_nope[i].dg
            # Wrap around instead of raising IndexError past the cycle length
            color = linecolors[i % len(linecolors)]
            ax1.plot(i1.aRMSE_t,z,label=labels[i],color=color)
            ax1.plot(i2.aRMSE_t,z,color=color,dashes=[3,1],alpha=0.3)
            #
            ax2.plot(i1.bRMSE_t,z,label=labels[i],color=color)
            ax2.plot(i2.bRMSE_t,z,color=color,dashes=[3,1],alpha=0.3)
            #
            ax3.plot(i1.bRMSE_t-i1.aRMSE_t,z,label=labels[i],color=color)
            ax3.plot(i2.bRMSE_t-i2.aRMSE_t,z,color=color,dashes=[3,1],alpha=0.3)
            #
            ax4.plot(i1.bSprd_t/i1.bRMSE_t,z_pbl,label=labels[i],color=color)
            ax4.plot(i2.bSprd_t/i2.bRMSE_t,z_pbl,color=color,dashes=[3,1],alpha=0.3)
        ax1.set_title('a) Analysis error')
        ax1.set_xlabel(r'RMSE$^a_\theta$')
        ax2.set_title('b) First-guess error')
        ax2.set_xlabel(r'RMSE$^b_\theta$')
        ax3.set_title('c) Error reduction')
        ax3.set_xlabel(r'RMSE$^b_\theta-$RMSE$^a_\theta$')
        ax4.set_title('d) Spread-error consistency')
        ax4.set_xlabel(r'$\sigma^b_\theta$/RMSE$^b_\theta$')
        ax1.set_ylabel('height (m)')
        ax3.set_ylabel('height (m)')
        #
        # Reference line for perfect spread-error consistency
        ax4.axvline(x=1,color='k',linewidth=0.5,dashes=[3,1])
        ax2.sharey(ax1)
        ax4.sharey(ax3)
        p.setp(ax2.get_yticklabels(), visible=False)
        p.setp(ax4.get_yticklabels(), visible=False)
        #
        fig.savefig(filename,format='png',dpi=300)
        p.close(fig)