diff --git a/sbmtm.py b/sbmtm.py
new file mode 100644
index 0000000000000000000000000000000000000000..94a7079d28328246693cc501b41ba8c9adaf766d
--- /dev/null
+++ b/sbmtm.py
@@ -0,0 +1,671 @@
+import os
+import sys
+import pandas as pd
+import numpy as np
+import scipy.sparse
+import matplotlib.pyplot as plt
+from collections import Counter, defaultdict
+import graph_tool.all as gt
+
+
+class sbmtm():
+    '''
+    Topic modeling with (hierarchical) stochastic block models.
+    '''
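+    ## Typical workflow (a minimal sketch; assumes `texts` is a list of
+    ## tokenized documents, e.g. texts = [['hello', 'world'], ...]):
+    ##
+    ##     model = sbmtm()
+    ##     model.make_graph(texts)
+    ##     model.fit()
+    ##     topics = model.topics(l=0, n=10)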
+
+    def __init__(self):
+        self.g = None ## network
+
+        self.words = [] ## list of word nodes
+        self.documents = [] ## list of document nodes
+
+        self.state = None ## inference state from graphtool
+        self.groups = {} ## placeholder for group-membership results (see get_groups)
+        self.mdl = np.nan ## minimum description length of inferred state
+        self.L = np.nan ## number of levels in hierarchy
+
+    def make_graph(self, list_texts, documents=None, counts=True, n_min=None):
+        '''
+        Load a corpus and generate the word-document network.
+
+        arguments:
+        - list_texts: list of documents, each a list of word-tokens
+
+        optional arguments:
+        - documents: list of str, titles of the documents (default: integer indices 0,...,D-1)
+        - counts: save edge-multiplicity as an edge property 'count' (default: True)
+        - n_min, int: filter out all word-nodes with fewer than n_min counts (default: None)
+        '''
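+        ## a minimal, hypothetical call:
+        ##   model.make_graph([['the', 'cat', 'sat'], ['the', 'dog', 'barked']],
+        ##                    documents=['doc-a', 'doc-b'])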
+        D = len(list_texts)
+
+        ## if there are no document titles, we assign integers 0,...,D-1
+        ## otherwise we use supplied titles
+        if documents is None:
+            list_titles = [str(h) for h in range(D)]
+        else:
+            list_titles = documents
+
+        ## create the graph
+        g = gt.Graph(directed=False)
+        ## define node properties
+        ## name: docs - title, words - 'word'
+        ## kind: docs - 0, words - 1
+        name = g.vp["name"] = g.new_vp("string")
+        kind = g.vp["kind"] = g.new_vp("int")
+        if counts:
+            ecount = g.ep["count"] = g.new_ep("int")
+
+        docs_add = defaultdict(lambda: g.add_vertex())
+        words_add = defaultdict(lambda: g.add_vertex())
+
+        ## add all document nodes first, so that doc-vertices occupy
+        ## indices 0,...,D-1 and word-vertices come after
+        for i_d in range(D):
+            title = list_titles[i_d]
+            d = docs_add[title]
+
+        ## add all documents and words as nodes
+        ## add all tokens as links
+        for i_d in range(D):
+            title = list_titles[i_d]
+            text = list_texts[i_d]
+
+            d = docs_add[title]
+            name[d] = title
+            kind[d] = 0
+            c = Counter(text)
+            for word, count in c.items():
+                w = words_add[word]
+                name[w] = word
+                kind[w] = 1
+                if counts:
+                    e = g.add_edge(d, w)
+                    ecount[e] = count
+                else:
+                    for n in range(count):
+                        g.add_edge(d, w)
+
+        ## filter out word-nodes with fewer than n_min counts
+        ## (with counts=True, out_degree() counts the number of distinct
+        ## documents a word occurs in rather than its total token count)
+        if n_min is not None:
+            v_n = g.new_vertex_property("int")
+            for v in g.vertices():
+                v_n[v] = v.out_degree()
+
+            v_filter =  g.new_vertex_property("bool")
+            for v in g.vertices():
+                if v_n[v] < n_min and g.vp['kind'][v]==1:
+                    v_filter[v] = False
+                else:
+                    v_filter[v] = True
+            g.set_vertex_filter(v_filter)
+            g.purge_vertices()
+            g.clear_filters()
+
+
+        self.g = g
+        self.words = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 1]
+        self.documents = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 0]
+
+    def make_graph_from_BoW_df(self, df, counts=True, n_min=None):
+        """
+        Load a graph from a Bag of Words DataFrame
+
+        arguments
+        -----------
+        df should be a DataFrame with where df.index is a list of words and df.columns a list of documents
+
+        optional arguments:
+        - counts: save edge-multiplicity as counts (default: True)
+        - n_min, int: filter all word-nodes with less than n_min counts (default None)
+
+        :type df: DataFrame
+        """
+        # make a graph
+        g = gt.Graph(directed=False)
+        ## define node properties
+        ## name: docs - title, words - 'word'
+        ## kind: docs - 0, words - 1
+        name = g.vp["name"] = g.new_vp("string")
+        kind = g.vp["kind"] = g.new_vp("int")
+        if counts:
+            ecount = g.ep["count"] = g.new_ep("int")
+
+        X = df.values
+
+        ## convert to a sparse matrix and validate integer counts
+        X = scipy.sparse.coo_matrix(X)
+
+        if not counts and X.dtype != int:
+            X_int = X.astype(int)
+            if not np.allclose(X.data, X_int.data):
+                raise ValueError('Data must be integer if counts=False')
+            X = X_int
+
+        docs_add = defaultdict(lambda: g.add_vertex())
+        words_add = defaultdict(lambda: g.add_vertex())
+
+        D = len(df.columns)
+        ## add all documents first
+        for i_d in range(D):
+            title = df.columns[i_d]
+            d = docs_add[title]
+            name[d] = title
+            kind[d] = 0
+
+        ## add all words
+        for i_d in range(len(df.index)):
+            word = df.index[i_d]
+            w = words_add[word]
+            name[w] = word
+            kind[w] = 1
+
+        ## add all tokens as links
+        for i_d in range(D):
+            title = df.columns[i_d]
+            text = df[title]
+            for i_w, word, count in zip(range(len(df.index)), df.index, text):
+                if count < 1:
+                    continue
+                if counts:
+                    e = g.add_edge(i_d, D + i_w)
+                    ecount[e] = count
+                else:
+                    for n in range(count):
+                        g.add_edge(i_d, D + i_w)
+
+        ## filter out word-nodes with fewer than n_min counts
+        if n_min is not None:
+            v_n = g.new_vertex_property("int")
+            for v in g.vertices():
+                v_n[v] = v.out_degree()
+
+            v_filter = g.new_vertex_property("bool")
+            for v in g.vertices():
+                if v_n[v] < n_min and g.vp['kind'][v] == 1:
+                    v_filter[v] = False
+                else:
+                    v_filter[v] = True
+            g.set_vertex_filter(v_filter)
+            g.purge_vertices()
+            g.clear_filters()
+
+        self.g = g
+        self.words = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 1]
+        self.documents = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 0]
+        return self
+
+    def save_graph(self, filename='graph.gt.gz'):
+        '''
+        Save the word-document network generated by make_graph() as filename.
+        Allows for loading the graph without calling make_graph().
+        '''
+        self.g.save(filename)
+
+    def load_graph(self, filename='graph.gt.gz'):
+        '''
+        Load a word-document network generated by make_graph() and saved with save_graph().
+        '''
+        self.g = gt.load_graph(filename)
+        self.words = [self.g.vp['name'][v] for v in self.g.vertices() if self.g.vp['kind'][v] == 1]
+        self.documents = [self.g.vp['name'][v] for v in self.g.vertices() if self.g.vp['kind'][v] == 0]
+
+
+    def fit(self, overlap=False, n_init=1, verbose=False, epsilon=1e-3):
+        '''
+        Fit the hierarchical SBM to the word-document network.
+        - overlap, bool (default: False): overlapping or non-overlapping groups.
+            Overlapping fits require a graph built with counts=False; note that
+            several downstream methods assume the non-overlapping case.
+        - n_init, int (default: 1): number of random initializations; the fit
+            with the minimum description length (MDL) is kept to avoid local minima.
+        - verbose, bool (default: False): print inference progress.
+        - epsilon (default: 1e-3): currently unused.
+        '''
+        g = self.g
+        if g is None:
+            print('No data to fit the SBM. Load some data first (make_graph)')
+        else:
+            if overlap and "count" in g.ep:
+                raise ValueError("When using overlapping SBMs, the graph must be constructed with 'counts=False'")
+            clabel = g.vp['kind']
+
+            state_args = {'clabel': clabel, 'pclabel': clabel}
+            if "count" in g.ep:
+                state_args["eweight"] = g.ep.count
+
+            ## the inference: keep the best of n_init fits (lowest MDL)
+            mdl = np.inf
+            for i_n_init in range(n_init):
+                base_type = gt.BlockState if not overlap else gt.OverlapBlockState
+                state_tmp = gt.minimize_nested_blockmodel_dl(g,
+                                                             state_args=dict(
+                                                                 base_type=base_type,
+                                                                 **state_args),
+                                                             multilevel_mcmc_args=dict(
+                                                                 verbose=verbose))
+                L = 0
+                for s in state_tmp.levels:
+                    L += 1
+                    if s.get_nonempty_B() == 2:
+                        break
+                ## truncate the hierarchy at the first level that separates only
+                ## docs from words (two non-empty groups), plus a trivial root
+                state_tmp = state_tmp.copy(bs=state_tmp.get_bs()[:L] + [np.zeros(1)])
+                if verbose:
+                    print(state_tmp)
+
+                mdl_tmp = state_tmp.entropy()
+                if mdl_tmp < mdl:
+                    mdl = mdl_tmp
+                    state = state_tmp.copy()
+
+            self.state = state
+            ## minimum description length
+            self.mdl = state.entropy()
+            ## number of non-trivial levels in the hierarchy
+            ## (the trivial top levels are not counted)
+            L = len(state.levels)
+            if L == 2:
+                self.L = 1
+            else:
+                self.L = L - 2
+
+
+    def plot(self, filename=None, nedges=1000):
+        '''
+        Plot the graph and group structure.
+        optional:
+        - filename, str: where to save the plot; if None, the plot is not saved
+        - nedges, int: number of edges to subsample for plotting (faster, uses less memory)
+        '''
+        self.state.draw(layout='bipartite', output=filename,
+                        subsample_edges=nedges, hshortcuts=1, hide=0)
+
+
+    def topics(self, l=0, n=10):
+        '''
+        Get the n most probable words for each word-group (topic) at level l.
+        Return a dict mapping topic index to a list of tuples (word, P(w|tw)).
+        '''
+        dict_groups = self.get_groups(l=l)
+
+        Bw = dict_groups['Bw']
+        p_w_tw = dict_groups['p_w_tw']
+
+        words = self.words
+
+        ## loop over all word-groups
+        dict_group_words = {}
+        for tw in range(Bw):
+            p_w_ = p_w_tw[:,tw]
+            ind_w_ = np.argsort(p_w_)[::-1]
+            list_words_tw = []
+            for i in ind_w_[:n]:
+                if p_w_[i] > 0:
+                    list_words_tw+=[(words[i],p_w_[i])]
+                else:
+                    break
+            dict_group_words[tw] = list_words_tw
+        return dict_group_words
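+
+    ## Example output of topics(l=0, n=3) (hypothetical values):
+    ##   {0: [('network', 0.05), ('graph', 0.04), ('node', 0.03)],
+    ##    1: [('gene', 0.06), ('cell', 0.05), ('protein', 0.04)]}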
+
+    def topicdist(self, doc_index, l=0):
+        '''
+        Get the topic mixture of the document with index doc_index at level l.
+        Return a list of tuples (tw, P(tw|d)).
+        '''
+        dict_groups = self.get_groups(l=l)
+
+        p_tw_d = dict_groups['p_tw_d']
+        list_topics_tw = []
+        for tw, p_tw in enumerate(p_tw_d[:, doc_index]):
+            list_topics_tw += [(tw, p_tw)]
+        return list_topics_tw
+
+    def clusters(self, l=0, n=10):
+        '''
+        Get the n 'most representative' documents from each document cluster at level l,
+        i.e. those with the largest contribution to the group-membership vector.
+        In the non-overlapping case, each document belongs to exactly one group with probability 1.
+        '''
+        dict_groups = self.get_groups(l=l)
+        Bd = dict_groups['Bd']
+        p_td_d = dict_groups['p_td_d']
+
+        docs = self.documents
+        ## loop over all doc-groups
+        dict_group_docs = {}
+        for td in range(Bd):
+            p_d_ = p_td_d[td,:]
+            ind_d_ = np.argsort(p_d_)[::-1]
+            list_docs_td = []
+            for i in ind_d_[:n]:
+                if p_d_[i] > 0:
+                    list_docs_td+=[(docs[i],p_d_[i])]
+                else:
+                    break
+            dict_group_docs[td] = list_docs_td
+        return dict_group_docs
+
+    def clusters_query(self,doc_index,l=0):
+        '''
+        Get all documents in the same group as the query-document.
+        Note: Works only for non-overlapping model.
+        For overlapping case, we need something else.
+        '''
+        dict_groups = self.get_groups(l=l)
+        p_td_d = dict_groups['p_td_d']
+
+        documents = self.documents
+        ## find the doc-group of the query-document
+        td = np.argmax(p_td_d[:, doc_index])
+
+        list_doc_index_sel = np.where(p_td_d[td,:]==1)[0]
+
+        list_doc_query = []
+
+        for doc_index_sel in list_doc_index_sel:
+            if doc_index != doc_index_sel:
+                list_doc_query += [(doc_index_sel,documents[doc_index_sel])]
+
+        return list_doc_query
+
+
+    def group_membership(self, l=0):
+        '''
+        Return the group-membership vectors for
+            - document-nodes, p_td_d, array with shape Bd x D
+            - word-nodes, p_tw_w, array with shape Bw x V
+        Each entry gives the probability that a node belongs to a given group.
+        '''
+        dict_groups = self.get_groups(l=l)
+        p_tw_w = dict_groups['p_tw_w']
+        p_td_d = dict_groups['p_td_d']
+        return p_td_d,p_tw_w
+
+
+    def print_topics(self, l=0, format='csv', path_save=''):
+        '''
+        Write topics, topic distributions, and document clusters for a given level
+        in the hierarchy to files in path_save.
+        format: 'csv' (default), 'html', or 'tsv' (topics table only)
+        '''
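+        ## e.g. l=0 with format='csv' writes topsbm_level_0_topics.csv,
+        ## topsbm_level_0_topic-dist.csv, and topsbm_level_0_clusters.csv
+        ## into path_save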
+        V = self.get_V()
+        D = self.get_D()
+
+        ## topics
+        dict_topics = self.topics(l=l,n=-1)
+
+        list_topics = sorted(list(dict_topics.keys()))
+        list_columns = ['Topic %s'%(t+1) for t in list_topics]
+
+        df = pd.DataFrame(columns=list_columns, index=range(V))
+
+
+        for t in list_topics:
+            list_w = [h[0] for h in dict_topics[t]]
+            V_t = len(list_w)
+            df.iloc[:V_t,t] = list_w
+        df=df.dropna(how='all',axis=0)
+        if format == 'csv':
+            fname_save = 'topsbm_level_%s_topics.csv'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_csv(filename,index=False,na_rep='')
+        elif format == 'html':
+            fname_save = 'topsbm_level_%s_topics.html'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_html(filename,index=False,na_rep='')
+        elif format=='tsv':
+            fname_save = 'topsbm_level_%s_topics.tsv'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_csv(filename,index=False,na_rep='',sep='\t')
+        else:
+            pass
+
+        ## topic distributions
+        list_columns = ['i_doc','doc']+['Topic %s'%(t+1) for t in list_topics]
+        df = pd.DataFrame(columns=list_columns,index=range(D))
+        for i_doc in range(D):
+            list_topicdist = self.topicdist(i_doc,l=l)
+            df.iloc[i_doc,0] = i_doc
+            df.iloc[i_doc,1] = self.documents[i_doc]
+            df.iloc[i_doc,2:] = [h[1] for h in list_topicdist]
+        df=df.dropna(how='all',axis=1)
+        if format == 'csv':
+            fname_save = 'topsbm_level_%s_topic-dist.csv'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_csv(filename,index=False,na_rep='')
+        elif format == 'html':
+            fname_save = 'topsbm_level_%s_topic-dist.html'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_html(filename,index=False,na_rep='')
+        else:
+            pass
+
+        ## doc-groups
+
+        dict_clusters = self.clusters(l=l,n=-1)
+
+        list_clusters = sorted(list(dict_clusters.keys()))
+        list_columns = ['Cluster %s'%(t+1) for t in list_clusters]
+
+        df = pd.DataFrame(columns=list_columns, index=range(D))
+
+
+        for t in list_clusters:
+            list_d = [h[0] for h in dict_clusters[t]]
+            D_t = len(list_d)
+            df.iloc[:D_t,t] = list_d
+        df=df.dropna(how='all',axis=0)
+        if format == 'csv':
+            fname_save = 'topsbm_level_%s_clusters.csv'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_csv(filename,index=False,na_rep='')
+        elif format == 'html':
+            fname_save = 'topsbm_level_%s_clusters.html'%(l)
+            filename = os.path.join(path_save,fname_save)
+            df.to_html(filename,index=False,na_rep='')
+        else:
+            pass
+
+    ###########
+    ########### HELPER FUNCTIONS
+    ###########
+    ## get group-topic statistics
+    def get_groups(self, l=0):
+        '''
+        Extract statistics on the group memberships of nodes from the inferred state.
+        Return a dictionary with
+        - Bd, int: number of doc-groups
+        - Bw, int: number of word-groups
+        - p_tw_w, array Bw x V; word-group membership:
+             prob that word-node w belongs to word-group tw: P(tw | w)
+        - p_td_d, array Bd x D; doc-group membership:
+             prob that doc-node d belongs to doc-group td: P(td | d)
+        - p_w_tw, array V x Bw; topic distribution:
+             prob of word w given topic tw: P(w | tw)
+        - p_tw_d, array Bw x D; doc-topic mixtures:
+             prob of word-group tw in doc d: P(tw | d)
+        '''
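+        ## normalization sanity checks (for a fitted, non-overlapping model):
+        ##   p_w_tw.sum(axis=0) == 1  (each topic is a distribution over words)
+        ##   p_tw_d.sum(axis=0) == 1  (each doc is a mixture over topics)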
+        V = self.get_V()
+        D = self.get_D()
+
+        g = self.g
+        state = self.state
+        state_l = state.project_level(l).copy(overlap=True)
+        state_l_edges = state_l.get_edge_blocks() ## labeled half-edges
+
+        counts = 'count' in g.ep
+
+        ## count labeled half-edges, group-memberships
+        B = state_l.get_B()
+        n_wb = np.zeros((V,B)) ## number of half-edges incident on word-node w and labeled as word-group tw
+        n_db = np.zeros((D,B)) ## number of half-edges incident on document-node d and labeled as document-group td
+        n_dbw = np.zeros((D,B)) ## number of half-edges incident on document-node d and labeled as word-group td
+
+        for e in g.edges():
+            z1, z2 = state_l_edges[e]
+            ## by construction the source is a doc-node and the target a word-node,
+            ## with doc-nodes occupying vertex indices 0,...,D-1
+            v1 = e.source()
+            v2 = e.target()
+            if counts:
+                weight = g.ep["count"][e]
+            else:
+                weight = 1
+            n_db[int(v1), z1] += weight
+            n_dbw[int(v1), z2] += weight
+            n_wb[int(v2) - D, z2] += weight
+
+        ## keep only the non-empty groups
+        ind_d = np.where(np.sum(n_db,axis=0)>0)[0]
+        Bd = len(ind_d)
+        n_db = n_db[:,ind_d]
+
+        ind_w = np.where(np.sum(n_wb,axis=0)>0)[0]
+        Bw = len(ind_w)
+        n_wb = n_wb[:,ind_w]
+
+        ind_w2 = np.where(np.sum(n_dbw,axis=0)>0)[0]
+        n_dbw = n_dbw[:,ind_w2]
+
+        ## group-membership distributions
+        # group membership of each word-node P(t_w | w)
+        p_tw_w = (n_wb/np.sum(n_wb,axis=1)[:,np.newaxis]).T
+
+        # group membership of each doc-node P(t_d | d)
+        p_td_d = (n_db/np.sum(n_db,axis=1)[:,np.newaxis]).T
+
+        ## topic-distribution for words P(w | t_w)
+        p_w_tw = n_wb/np.sum(n_wb,axis=0)[np.newaxis,:]
+
+        ## mixture of word-groups into documents P(t_w | d)
+        p_tw_d = (n_dbw/np.sum(n_dbw,axis=1)[:,np.newaxis]).T
+
+
+        result = {}
+        result['Bd'] = Bd
+        result['Bw'] = Bw
+        result['p_tw_w'] = p_tw_w
+        result['p_td_d'] = p_td_d
+        result['p_w_tw'] = p_w_tw
+        result['p_tw_d'] = p_tw_d
+
+        return result
+
+    ### helper functions
+
+    def get_V(self):
+        '''
+        return the number of word-nodes == word types
+        '''
+        return int(np.sum(self.g.vp['kind'].a == 1))  ## no. of word types
+
+    def get_D(self):
+        '''
+        return the number of doc-nodes == number of documents
+        '''
+        return int(np.sum(self.g.vp['kind'].a == 0))  ## no. of documents
+
+    def get_N(self):
+        '''
+        return the number of edges; equal to the number of tokens only when the
+        graph was built with counts=False (otherwise edges carry multiplicities)
+        '''
+        return int(self.g.num_edges())  ## no. of edges
+
+    def group_to_group_mixture(self, l=0, norm=True):
+        '''
+        Return the Bd x Bw matrix of edge counts between doc-groups and word-groups
+        at level l; if norm=True, normalize it to a joint distribution P(td, tw).
+        '''
+        g = self.g
+        state = self.state
+        state_l = state.project_level(l).copy(overlap=True)
+        state_l_edges = state_l.get_edge_blocks() ## labeled half-edges
+
+        ## count labeled half-edges, group-memberships
+        B = state_l.get_B()
+        n_td_tw = np.zeros((B,B))
+
+        counts = 'count' in g.ep
+
+        for e in g.edges():
+            z1,z2 = state_l_edges[e]
+            if counts:
+                n_td_tw[z1 , z2] += g.ep["count"][e]
+            else:
+                n_td_tw[z1, z2] += 1
+
+
+        ind_d = np.where(np.sum(n_td_tw,axis=1)>0)[0]
+        Bd = len(ind_d)
+        ind_w = np.where(np.sum(n_td_tw,axis=0)>0)[0]
+        Bw = len(ind_w)
+
+        ## doc-groups are assumed to occupy the first Bd block labels,
+        ## word-groups the remaining ones
+        n_td_tw = n_td_tw[:Bd, Bd:]
+        if norm:
+            return n_td_tw / np.sum(n_td_tw)
+        else:
+            return n_td_tw
+
+    def pmi_td_tw(self, l=0):
+        '''
+        Pointwise mutual information between doc-groups and topic-groups, S(td,tw),
+        an array of shape Bd x Bw.
+
+        It corresponds to
+        S(td,tw) = log [ P(tw | td) / \tilde{P}(tw | td) ],
+
+        the log-ratio between
+        P(tw | td), the probability of topic tw in doc-group td, and
+        \tilde{P}(tw | td) = P(tw), the expected probability of topic tw in
+        doc-group td under a random null model.
+        '''
+        p_td_tw = self.group_to_group_mixture(l=l)
+        p_tw_td = p_td_tw.T
+        p_td = np.sum(p_tw_td,axis=0)
+        p_tw = np.sum(p_tw_td,axis=1)
+        pmi_td_tw = np.log(p_tw_td/(p_td*p_tw[:,np.newaxis])).T
+        return pmi_td_tw
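+
+    ## Interpretation: S(td, tw) = log 2 means topic tw occurs twice as often in
+    ## doc-group td as expected under the null model; S(td, tw) = 0 means it
+    ## occurs at exactly its corpus-wide rate.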
+
+
+    def print_summary(self, tofile=True):
+        '''
+        Print a summary of the group hierarchy.
+        - tofile, bool (default: True): write the summary to 'summary.txt'
+          instead of stdout.
+        '''
+        if tofile:
+            orig_stdout = sys.stdout
+            with open('summary.txt', 'w') as f:
+                sys.stdout = f
+                self.state.print_summary()
+            sys.stdout = orig_stdout
+        else:
+            self.state.print_summary()
+
+    def plot_topic_dist(self, l):
+        '''
+        Plot the topic distributions P(w|tw) and P(tw|d) at level l and save them
+        as p_w_tw_<l>.png and p_tw_d_<l>.png.
+        '''
+        groups = self.get_groups(l=l)
+        p_w_tw = groups['p_w_tw']
+        fig=plt.figure(figsize=(12,10))
+        plt.imshow(p_w_tw,origin='lower',aspect='auto',interpolation='none')
+        plt.title(r'Word group membership $P(w | tw)$')
+        plt.xlabel('Topic, tw')
+        plt.ylabel('Word w (index)')
+        plt.colorbar()
+        fig.savefig("p_w_tw_%d.png"%l)
+        p_tw_d = groups['p_tw_d']
+        fig=plt.figure(figsize=(12,10))
+        plt.imshow(p_tw_d,origin='lower',aspect='auto',interpolation='none')
+        plt.title(r'Word group membership $P(tw | d)$')
+        plt.xlabel('Document (index)')
+        plt.ylabel('Topic, tw')
+        plt.colorbar()
+        fig.savefig("p_tw_d_%d.png"%l)
+
+    def save_data(self):
+        '''
+        Save topics, topic distributions, clusters, and block matrices for every
+        non-trivial level of the hierarchy.
+        '''
+        for i in range(len(self.state.get_levels()) - 2)[::-1]:
+            print("Saving level %d" % i)
+            self.print_topics(l=i)
+            self.print_topics(l=i, format='tsv')
+            self.plot_topic_dist(i)
+            e = self.state.get_levels()[i].get_matrix()
+            plt.matshow(e.todense())
+            plt.savefig("mat_%d.png"%i)
+        self.print_summary()
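+
+
+if __name__ == '__main__':
+    ## a minimal end-to-end sketch on a toy corpus; the texts and titles here
+    ## are hypothetical stand-ins for a real tokenized corpus
+    gt.seed_rng(32)  ## inference is stochastic; fix graph-tool's RNG for repeatability
+    texts = [['cat', 'dog', 'pet', 'dog'],
+             ['gene', 'cell', 'protein', 'cell'],
+             ['dog', 'pet', 'leash', 'cat'],
+             ['cell', 'protein', 'dna', 'gene']]
+    model = sbmtm()
+    model.make_graph(texts, documents=['pets-1', 'bio-1', 'pets-2', 'bio-2'])
+    model.fit()
+    print(model.topics(l=0, n=3))
+    print(model.clusters(l=0, n=2))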