diff --git a/graphics.py b/graphics.py
index 568af059cc7fc69b0f43576917daec151894be8b..56f6fda759478d5027a20dd8563057fb7154fa7e 100644
--- a/graphics.py
+++ b/graphics.py
@@ -235,242 +235,118 @@ def plot_CBL_PE(exp,figtitle,parameter_id=0,plot_spread=False,ax=None):
         else:
             return ax1
 
-def plot_CBL_identifiability(cbl,sigma_o,figtitle,ax=None):
+def plot_CBL_identifiability(cbl,sigma_o,ax=None):
 
     cmap = p.get_cmap('RdBu_r')
     new_cmap = truncate_colormap(cmap, 0.5, 1.0)
-
     zmax = 2000
     ncont = 13
 
-    work_in_observation_space = True
-
-    if work_in_observation_space:
-        if isinstance(cbl,experiment):
-            # Read relevant dimensions
-            zt = cbl.da.zt
-            nz = cbl.da.nz
-            nens = cbl.da.nens
-            nobs = cbl.da.observations.shape[1]
-            npar = cbl.nr.parameter_number
-            ntimes = cbl.da.ncycles
-            times = cbl.da.time
-            zobs = cbl.da.obs_coordinates
-            theta = cbl.da.model_equivalents  # ntim*nobs*nens
-            pfac = cbl.da.backgrounds[:,-npar,:] # ntim*nx*nens
-            pfac = cbl.nr.parameter_transform[0](pfac,kind='inv')
-            if hasattr(cbl.da,'cov_py'):
-                sigma2_yb = cbl.da.cov_yy # ntim*nobs
-                sigma2_p = np.squeeze(cbl.da.cov_pp) # ntim*nobs(*npar)
-                covariance = np.squeeze(cbl.da.cov_py) # ntim*nobs(*npar)
-                correlation = covariance/(sigma2_yb*sigma2_p)**0.5
-                innov = cbl.da.innov
-                deltap = np.squeeze(cbl.da.increments) # ntim*nobs(*npar)
+    # cbl must be an "experiment" instance
+    # Read relevant dimensions
+    zt = cbl.da.zt
+    nz = cbl.da.nz
+    nens = cbl.da.nens
+    nobs = cbl.da.observations.shape[1]
+    npar = cbl.nr.parameter_number
+    ntimes = cbl.da.ncycles
+    times = cbl.da.time
+    zobs = cbl.da.obs_coordinates
+    theta = cbl.da.model_equivalents  # ntim*nobs*nens
+    pfac = cbl.da.backgrounds[:,-npar,:] # ntim*nx*nens
+    pfac = cbl.nr.parameter_transform[0](pfac,kind='inv')
+    if hasattr(cbl.da,'cov_py'):
+        sigma2_yb = cbl.da.cov_yy # ntim*nobs
+        sigma2_p = np.squeeze(cbl.da.cov_pp) # ntim*nobs(*npar)
+        covariance = np.squeeze(cbl.da.cov_py) # ntim*nobs(*npar)
+        correlation = covariance/(sigma2_yb*sigma2_p)**0.5
+        innov = cbl.da.innov
+        deltap = np.squeeze(cbl.da.increments) # ntim*nobs(*npar)
+
+    # Shorthand for plots
+    A = correlation.T
+    B = ( innov*(sigma2_p*sigma2_yb)**0.5/(sigma2_yb+sigma_o**2) ).T
+    C = deltap.T
+    D = ( (sigma2_p/sigma2_yb)**0.5 ).T
+    E = ( innov*sigma2_yb/(sigma2_yb+sigma_o**2) ).T
+
+    # For plotting
+    contours = np.linspace(cbl.nr.theta_0,cbl.nr.theta_0+cbl.nr.gamma*zmax,ncont)
+
+    # Make plots
+    if ax is None:
+        pass
 
-        else:
-            pass
+    elif type(ax) is p.Axes:
 
-        # For plotting
-        contours = np.linspace(cbl.nr.theta_0,cbl.nr.theta_0+cbl.nr.gamma*zmax,ncont)
+        c0 = ax.pcolormesh(times/3600.,zobs,A,vmin=-1,vmax=1,cmap='RdBu_r')
+        ax.contour(times/3600.,zobs,theta.mean(axis=2).T,
+                    contours,
+                    colors='black',
+                    linestyles='--',
+                    linewidths=0.75)
+        ax.set_ylim([0,zmax])
 
-        # Make plots
-        if ax is None:
-            pass
-        elif type(ax) is p.Axes:
-            # Make plots
-            c = ax.pcolormesh(times/3600.,zobs,correlation,norm=colors.CenteredNorm(),cmap='RdBu_r')
-            ax.contour(times/3600.,zobs,theta.mean(axis=2).T,
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax.set_ylim([0,zmax])
-            return ax,c
-        elif type(ax) is list and len(ax)==2:
-            # Make plots
-            c0 = ax[0].pcolormesh(times/3600.,zobs,correlation,norm=colors.CenteredNorm(),cmap='RdBu_r')
-            ax[0].contour(times/3600.,zobs,theta.mean(axis=2).T,
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax[0].set_ylim([0,zmax])
-
-            c1 = ax[1].pcolormesh(times/3600.,zobs,np.sqrt(sigma2_p[None,:]/sigma2_yb),cmap=new_cmap)
-            ax[1].contour(times/3600.,zobs,theta.mean(axis=2).T,
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax[1].set_ylim([0,zmax])
-            return ax,c0,c1
-        elif type(ax) is list and len(ax)==3:
-            # Make plots
-            A = correlation.T
-            B = ( innov*(sigma2_p*sigma2_yb)**0.5/(sigma2_yb+sigma_o**2) ).T
-            C = deltap.T
-
-            c0 = ax[0].pcolormesh(times/3600.,zobs,A,vmin=-1,vmax=1,cmap='RdBu_r')
-            ax[0].contour(times/3600.,zobs,theta.mean(axis=2).T,
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax[0].set_ylim([0,zmax])
-
-            c1 = ax[1].pcolormesh(times/3600.,zobs,B,norm=colors.CenteredNorm(),cmap='RdBu_r')#,cmap=new_cmap)
-            ax[1].contour(times/3600.,zobs,theta.mean(axis=2).T,
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax[1].set_ylim([0,zmax])
-
-            c2 = ax[2].pcolormesh(times/3600.,zobs,C,norm=colors.CenteredNorm(),cmap='RdBu_r')
-            ax[2].contour(times/3600.,zobs,theta.mean(axis=2).T,
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax[2].set_ylim([0,zmax])
-
-            return ax,c0,c1,c2
+        return ax,c0
 
-    else:
-        if isinstance(cbl,experiment):
-            # Read relevant dimensions
-            zt = cbl.da.zt
-            nz = cbl.da.nz
-            nens = cbl.da.nens
-            nobs = cbl.da.observations.shape[1]
-            npar = cbl.nr.parameter_number
-
-            # Reconstruct time array
-            times = np.array([])
-            beg = 0
-            for i in range(len(cbl.da.history)):
-                times = np.append(times,cbl.da.history[i]['0000']['time']+beg)
-                beg = beg+(cbl.da.history[i]['0000']['time'].size-1)*cbl.nr.dt
-            ntimes = times.size
-
-            # Reconstruct history
-            theta = np.zeros((nens,nz,ntimes))+np.nan
-            pfac = np.zeros((nens,ntimes))+np.nan
-            beg = 0
-            for i in range(len(cbl.da.history)):
-                end = cbl.da.history[i]['0000']['time'].size
-                for k in range(nens):
-                    theta[k,:,beg:beg+end]= cbl.da.history[i]['%04u'%k]['theta']
-                    pfac[k,beg:beg+end]= cbl.da.backgrounds[i,-npar,k]
-                beg = beg+end
-            pfac = cbl.nr.parameter_transform[0](pfac,kind='inv')
-
-            # Compute Pearson correlation coefficient and Kalman Gain
-            # (a little noise is added to the denominator to avoid singularities)
-            correlation = np.zeros((nz,ntimes))+np.nan
-            kalman_gain = np.zeros((nz,ntimes))+np.nan
-            for i in range(nz):
-                for j in range(ntimes):
-                    covmat=np.cov(pfac[:,j],theta[:,i,j])
-                    correlation[i,j] = covmat[0,1] / (1e-9+np.sqrt(covmat[0,0]*covmat[1,1]))
-                    kalman_gain[i,j] = covmat[0,1] / (1e-9+covmat[1,1]+sigma_o**2)
-
-            # For plotting
-            contours = np.linspace(cbl.nr.theta_0,cbl.nr.theta_0+cbl.nr.gamma*zmax,ncont)
+    elif type(ax) is list and len(ax)==2:
 
-        else:
-            # Read relevant dimensions
-            zt = cbl.zt
-            nz = cbl.nz
-            nens = cbl.nens
-            times = cbl.history['0000']['time']
-            ntimes = times.size
-
-            # Reconstruct_history
-            theta = np.zeros((nens,nz,ntimes))+np.nan
-            for k in range(nens):
-                theta[k,:,:]= cbl.history['%04u'%k]['theta']
-
-            # Read parameter values
-            if hasattr(cbl,'initial_perturbed_parameters'):
-                pfac = cbl.parameter_transform[0](cbl.initial_perturbed_parameters[0],kind='inv')
-            else:
-                pfac = np.ones(nens)*cbl.pfac
-
-            # Compute proxies of Pearson correlation coefficient and Kalman Gain
-            # (a little noise is added to the denominator to avoid singularities)
-            correlation = np.zeros((nz,ntimes))+np.nan
-            kalman_gain = np.zeros((nz,ntimes))+np.nan
-            for i in range(nz):
-                for j in range(ntimes):
-                    covmat=np.cov(pfac,theta[:,i,j])
-                    correlation[i,j] = covmat[0,1] / (1e-9+np.sqrt(covmat[0,0]*covmat[1,1]))
-                    kalman_gain[i,j] = covmat[0,1] / (1e-9+covmat[1,1]+sigma_o**2)
-            
-            # For plotting
-            contours = np.linspace(cbl.theta_0,cbl.theta_0+cbl.gamma*zmax,ncont)
+        c0 = ax[0].pcolormesh(times/3600.,zobs,D,cmap=new_cmap)
+        ax[0].contour(times/3600.,zobs,theta.mean(axis=2).T,
+                    contours,
+                    colors='black',
+                    linestyles='--',
+                    linewidths=0.75)
+        ax[0].set_ylim([0,zmax])
 
-        if ax is None:
+        c1 = ax[1].pcolormesh(times/3600.,zobs,E,norm=colors.CenteredNorm(),cmap='RdBu_r')
+        ax[1].contour(times/3600.,zobs,theta.mean(axis=2).T,
+                    contours,
+                    colors='black',
+                    linestyles='--',
+                    linewidths=0.75)
+        ax[1].set_ylim([0,zmax])
 
-            # Make plots
-            fig = p.figure(151)
-            fig.set_size_inches(9,3)
-            #
-            ax1 = fig.add_subplot(1,3,1)
-            c1 = ax1.pcolormesh(times,zt,theta.mean(axis=0),vmin=cbl.theta_0,vmax=cbl.theta_0+cbl.gamma*zmax)
-            ax1.contour(times,zt,theta.mean(axis=0),
-                        np.linspace(cbl.theta_0,cbl.theta_0+cbl.gamma*zmax,ncont),
-                        colors='white',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax1.set_ylim([0,zmax])
-            ax1.set_xlabel(r'time')
-            ax1.set_ylabel(r'$z$ (m)')
-            ax1.set_title(r'ensemble mean $\theta$ (K)')
-            p.colorbar(c1)
-            #
-            ax2 = fig.add_subplot(1,3,2)
-            c2 = ax2.pcolormesh(times,zt,theta.std(axis=0))
-            ax2.set_ylim([0,zmax])
-            ax2.set_xlabel(r'time')
-            ax2.set_ylabel(r'$z$ (m)')
-            ax2.set_title(r'ensemble sdev $\theta$ (K)')
-            p.colorbar(c2)
-            #
-            ax3 = fig.add_subplot(1,3,3,sharey=ax1)
-            c3 = ax3.pcolormesh(times,zt,correlation,vmin=-1,vmax=1,cmap='RdBu_r')
-            ax3.set_xlabel(r'time')
-            ax3.set_title(r'$p$-$\theta$ correlation')
-            p.colorbar(c3)
-            #
-            p.setp(ax2.get_yticklabels(), visible=False)
-            p.setp(ax3.get_yticklabels(), visible=False)
-            fig.tight_layout()
-            fig.savefig(figtitle,format='png',dpi=300)
-            p.close(fig)
-        else:
-            # Make plots
-            c = ax.pcolormesh(times/3600.,zt,kalman_gain,norm=colors.CenteredNorm(),cmap='RdBu_r')
-            #c = ax.pcolormesh(times/3600.,zt,kalman_gain,vmin=-5, vmax=5,cmap='RdBu_r')
-            ax.contour(times/3600.,zt,theta.mean(axis=0),
-                        contours,
-                        colors='black',
-                        linestyles='--',
-                        linewidths=0.75)
-            ax.set_ylim([0,zmax])
-
-            make_parameter_state_scatterplots = False
-            if make_parameter_state_scatterplots:
-                fig, [[ax1, ax2],[ax3, ax4]] = p.subplots(2,2,constrained_layout=True)
-                fig.set_size_inches(4,4)
-                ax3.scatter(theta[:,0,0] , pfac) #ll
-                ax4.scatter(theta[:,0,-1], pfac) #lr
-                ax1.scatter(theta[:,-1,0], pfac) #ul
-                ax2.scatter(theta[:,-1,-1],pfac) #ur
-                fig.savefig('scatter.png',format='png',dpi=300)
-                p.close(fig)
-
-            return ax,c
+        return ax,c0,c1
+
+    elif type(ax) is list and len(ax)==3:
+
+        c0 = ax[0].pcolormesh(times/3600.,zobs,A,vmin=-1,vmax=1,cmap='RdBu_r')
+        ax[0].contour(times/3600.,zobs,theta.mean(axis=2).T,
+                    contours,
+                    colors='black',
+                    linestyles='--',
+                    linewidths=0.75)
+        ax[0].set_ylim([0,zmax])
+
+        c1 = ax[1].pcolormesh(times/3600.,zobs,B,norm=colors.CenteredNorm(),cmap='RdBu_r')
+        ax[1].contour(times/3600.,zobs,theta.mean(axis=2).T,
+                    contours,
+                    colors='black',
+                    linestyles='--',
+                    linewidths=0.75)
+        ax[1].set_ylim([0,zmax])
+
+        c2 = ax[2].pcolormesh(times/3600.,zobs,C,norm=colors.CenteredNorm(),cmap='RdBu_r')
+        ax[2].contour(times/3600.,zobs,theta.mean(axis=2).T,
+                    contours,
+                    colors='black',
+                    linestyles='--',
+                    linewidths=0.75)
+        ax[2].set_ylim([0,zmax])
+
+    make_parameter_state_scatterplots = False
+    if make_parameter_state_scatterplots:
+        fig, [[ax1, ax2],[ax3, ax4]] = p.subplots(2,2,constrained_layout=True)
+        fig.set_size_inches(4,4)
+        ax3.scatter(theta[:,0,0] , pfac) #ll
+        ax4.scatter(theta[:,0,-1], pfac) #lr
+        ax1.scatter(theta[:,-1,0], pfac) #ul
+        ax2.scatter(theta[:,-1,-1],pfac) #ur
+        fig.savefig('p_scatter_%s.png'%cbl.label,format='png',dpi=300)
+        p.close(fig)
+
+    return ax,c0,c1,c2
 
 def plot_p(p_factors,theta_profiles,zt,figtitle,ax=None):
 
@@ -571,7 +447,6 @@ def plot_diagnostics(experiments_pe,experiments_nope,labels,filename):
     ax1.set_ylabel('height (m)')
     ax3.set_ylabel('height (m)')
     #
-    #ax2.legend(frameon=False)
     ax4.axvline(x=1,color='k',linewidth=0.5,dashes=[3,1])
     ax2.sharey(ax1)
     ax4.sharey(ax3)