Christian Göbel / ChineseTopSBM / Commits

Commit 4a32b269, authored 3 years ago by Christian Göbel

    initial commit

Parent: c2456ea5
No related branches, tags, or merge requests found.

Showing 1 changed file: sbmtm.py (new file, mode 100644), 671 additions, 0 deletions
from __future__ import print_function

import os
import sys
import argparse
import pickle
from collections import Counter, defaultdict

import pandas as pd
import numpy as np
import scipy.sparse  ## used by make_graph_from_BoW_df; this import was missing in the original file
import matplotlib.pyplot as plt

import graph_tool.all as gt
class sbmtm():
    '''
    Class for topic modeling with sbm's.
    '''

    def __init__(self):
        self.g = None  ## network
        self.words = []  ## list of word nodes
        self.documents = []  ## list of document nodes
        self.state = None  ## inference state from graphtool
        self.groups = {}  ## results of group membership from inference
        self.mdl = np.nan  ## minimum description length of inferred state
        self.L = np.nan  ## number of levels in hierarchy
    def make_graph(self, list_texts, documents=None, counts=True, n_min=None):
        '''
        Load a corpus and generate the word-document network.
        optional arguments:
        - documents: list of str, titles of documents
        - counts: save edge-multiplicity as counts (default: True)
        - n_min, int: filter all word-nodes with fewer than n_min counts (default: None)
        '''
        D = len(list_texts)

        ## if there are no document titles, we assign integers 0,...,D-1
        ## otherwise we use the supplied titles
        if documents is None:
            list_titles = [str(h) for h in range(D)]
        else:
            list_titles = documents

        ## create a graph
        g = gt.Graph(directed=False)

        ## define node properties
        ## name: docs - title, words - 'word'
        ## kind: docs - 0, words - 1
        name = g.vp["name"] = g.new_vp("string")
        kind = g.vp["kind"] = g.new_vp("int")
        if counts:
            ecount = g.ep["count"] = g.new_ep("int")

        docs_add = defaultdict(lambda: g.add_vertex())
        words_add = defaultdict(lambda: g.add_vertex())

        ## add all documents first
        for i_d in range(D):
            title = list_titles[i_d]
            d = docs_add[title]

        ## add all documents and words as nodes
        ## add all tokens as links
        for i_d in range(D):
            title = list_titles[i_d]
            text = list_texts[i_d]

            d = docs_add[title]
            name[d] = title
            kind[d] = 0
            c = Counter(text)
            for word, count in c.items():
                w = words_add[word]
                name[w] = word
                kind[w] = 1
                if counts:
                    e = g.add_edge(d, w)
                    ecount[e] = count
                else:
                    for n in range(count):
                        g.add_edge(d, w)

        ## filter word-types with fewer than n_min counts
        if n_min is not None:
            v_n = g.new_vertex_property("int")
            for v in g.vertices():
                v_n[v] = v.out_degree()

            v_filter = g.new_vertex_property("bool")
            for v in g.vertices():
                if v_n[v] < n_min and g.vp['kind'][v] == 1:
                    v_filter[v] = False
                else:
                    v_filter[v] = True
            g.set_vertex_filter(v_filter)
            g.purge_vertices()
            g.clear_filters()

        self.g = g
        self.words = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 1]
        self.documents = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 0]
    def make_graph_from_BoW_df(self, df, counts=True, n_min=None):
        """
        Load a graph from a Bag-of-Words DataFrame.
        arguments
        -----------
        df: DataFrame whose index is a list of words and whose columns are a list of documents
        optional arguments:
        - counts: save edge-multiplicity as counts (default: True)
        - n_min, int: filter all word-nodes with fewer than n_min counts (default: None)
        :type df: DataFrame
        """
        ## make a graph
        g = gt.Graph(directed=False)

        ## define node properties
        ## name: docs - title, words - 'word'
        ## kind: docs - 0, words - 1
        name = g.vp["name"] = g.new_vp("string")
        kind = g.vp["kind"] = g.new_vp("int")
        if counts:
            ecount = g.ep["count"] = g.new_ep("int")

        X = df.values
        X = scipy.sparse.coo_matrix(X)
        if not counts and X.dtype != int:
            X_int = X.astype(int)
            if not np.allclose(X.data, X_int.data):
                raise ValueError('Data must be integer if counts=False')
            X = X_int

        docs_add = defaultdict(lambda: g.add_vertex())
        words_add = defaultdict(lambda: g.add_vertex())

        D = len(df.columns)

        ## add all documents first
        for i_d in range(D):
            title = df.columns[i_d]
            d = docs_add[title]
            name[d] = title
            kind[d] = 0

        ## add all words
        for i_d in range(len(df.index)):
            word = df.index[i_d]
            w = words_add[word]
            name[w] = word
            kind[w] = 1

        ## add all tokens as links
        for i_d in range(D):
            title = df.columns[i_d]
            text = df[title]

            for i_w, word, count in zip(range(len(df.index)), df.index, text):
                if count < 1:
                    continue
                if counts:
                    e = g.add_edge(i_d, D + i_w)
                    ecount[e] = count
                else:
                    for n in range(count):
                        g.add_edge(i_d, D + i_w)

        ## filter word-types with fewer than n_min counts
        if n_min is not None:
            v_n = g.new_vertex_property("int")
            for v in g.vertices():
                v_n[v] = v.out_degree()

            v_filter = g.new_vertex_property("bool")
            for v in g.vertices():
                if v_n[v] < n_min and g.vp['kind'][v] == 1:
                    v_filter[v] = False
                else:
                    v_filter[v] = True
            g.set_vertex_filter(v_filter)
            g.purge_vertices()
            g.clear_filters()

        self.g = g
        self.words = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 1]
        self.documents = [g.vp['name'][v] for v in g.vertices() if g.vp['kind'][v] == 0]
        return self
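    ## A minimal sketch of the expected DataFrame orientation (hypothetical data,
    ## not part of the original file): words on the index, documents on the columns.
    ##   df = pd.DataFrame([[2, 0], [1, 3]],
    ##                     index=['apple', 'banana'],   ## words
    ##                     columns=['doc_0', 'doc_1'])  ## documents
    ##   model = sbmtm().make_graph_from_BoW_df(df)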
    def save_graph(self, filename='graph.gt.gz'):
        '''
        Save the word-document network generated by make_graph() as filename.
        Allows for loading the graph without calling make_graph().
        '''
        self.g.save(filename)

    def load_graph(self, filename='graph.gt.gz'):
        '''
        Load a word-document network generated by make_graph() and saved with save_graph().
        '''
        self.g = gt.load_graph(filename)
        self.words = [self.g.vp['name'][v] for v in self.g.vertices() if self.g.vp['kind'][v] == 1]
        self.documents = [self.g.vp['name'][v] for v in self.g.vertices() if self.g.vp['kind'][v] == 0]
    def fit(self, overlap=False, n_init=1, verbose=False, epsilon=1e-3):
        '''
        Fit the sbm to the word-document network.
        - overlap, bool (default: False): overlapping or non-overlapping groups.
          Overlapping is not implemented yet.
        - n_init, int (default: 1): number of different initial conditions to run
          in order to avoid local minima of the MDL.
        '''
        g = self.g
        if g is None:
            print('No data to fit the SBM. Load some data first (make_graph)')
        else:
            if overlap and "count" in g.ep:
                raise ValueError("When using overlapping SBMs, the graph must be constructed with 'counts=False'")
            clabel = g.vp['kind']

            state_args = {'clabel': clabel, 'pclabel': clabel}
            if "count" in g.ep:
                state_args["eweight"] = g.ep.count

            ## the inference
            mdl = np.inf  ## description length of the best state found so far
            for i_n_init in range(n_init):
                base_type = gt.BlockState if not overlap else gt.OverlapBlockState
                state_tmp = gt.minimize_nested_blockmodel_dl(
                    g,
                    state_args=dict(base_type=base_type, **state_args),
                    multilevel_mcmc_args=dict(verbose=verbose))

                ## truncate the hierarchy at the first level with only 2 non-empty groups
                L = 0
                for s in state_tmp.levels:
                    L += 1
                    if s.get_nonempty_B() == 2:
                        break
                state_tmp = state_tmp.copy(bs=state_tmp.get_bs()[:L] + [np.zeros(1)])

                # state_tmp = state_tmp.copy(sampling=True)
                # delta = 1 + epsilon
                # while abs(delta) > epsilon:
                #     delta = state_tmp.multiflip_mcmc_sweep(niter=10, beta=np.inf)[0]
                #     print(delta)

                print(state_tmp)

                mdl_tmp = state_tmp.entropy()
                if mdl_tmp < mdl:
                    mdl = 1.0 * mdl_tmp
                    state = state_tmp.copy()

            self.state = state
            ## minimum description length
            self.mdl = state.entropy()
            L = len(state.levels)
            if L == 2:
                self.L = 1
            else:
                self.L = L - 2
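    ## Sketch of a typical call (hypothetical names, assuming a graph was built
    ## with make_graph): run several initializations and keep the state with the
    ## smallest description length.
    ##   model.fit(n_init=5)
    ##   print(model.mdl, model.L)  ## best description length and hierarchy depth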
    def plot(self, filename=None, nedges=1000):
        '''
        Plot the graph and group structure.
        optional:
        - filename, str: where to save the plot; if None, it will not be saved
        - nedges, int: number of edges to subsample for the plot (faster, less memory)
        '''
        self.state.draw(layout='bipartite', output=filename,
                        subsample_edges=nedges, hshortcuts=1, hide=0)
    def topics(self, l=0, n=10):
        '''
        Get the n most common words for each word-group in level l.
        Return tuples (word, P(w | tw)).
        '''
        # dict_groups = self.groups[l]
        dict_groups = self.get_groups(l=l)
        Bw = dict_groups['Bw']
        p_w_tw = dict_groups['p_w_tw']

        words = self.words

        ## loop over all word-groups
        dict_group_words = {}
        for tw in range(Bw):
            p_w_ = p_w_tw[:, tw]
            ind_w_ = np.argsort(p_w_)[::-1]
            list_words_tw = []
            for i in ind_w_[:n]:
                if p_w_[i] > 0:
                    list_words_tw += [(words[i], p_w_[i])]
                else:
                    break
            dict_group_words[tw] = list_words_tw
        return dict_group_words
    def topicdist(self, doc_index, l=0):
        '''
        Get the topic distribution of the document with index doc_index in level l.
        Return tuples (tw, P(tw | d)).
        '''
        # dict_groups = self.groups[l]
        dict_groups = self.get_groups(l=l)
        p_tw_d = dict_groups['p_tw_d']

        list_topics_tw = []
        for tw, p_tw in enumerate(p_tw_d[:, doc_index]):
            list_topics_tw += [(tw, p_tw)]
        return list_topics_tw
    def clusters(self, l=0, n=10):
        '''
        Get the n 'most common' documents from each document cluster.
        'Most common' refers to the largest contribution in the group-membership vector.
        For the non-overlapping case, each document belongs to one and only one group with probability 1.
        '''
        # dict_groups = self.groups[l]
        dict_groups = self.get_groups(l=l)
        Bd = dict_groups['Bd']
        p_td_d = dict_groups['p_td_d']

        docs = self.documents

        ## loop over all doc-groups
        dict_group_docs = {}
        for td in range(Bd):
            p_d_ = p_td_d[td, :]
            ind_d_ = np.argsort(p_d_)[::-1]
            list_docs_td = []
            for i in ind_d_[:n]:
                if p_d_[i] > 0:
                    list_docs_td += [(docs[i], p_d_[i])]
                else:
                    break
            dict_group_docs[td] = list_docs_td
        return dict_group_docs
    def clusters_query(self, doc_index, l=0):
        '''
        Get all documents in the same group as the query document.
        Note: works only for the non-overlapping model.
        For the overlapping case, we need something else.
        '''
        # dict_groups = self.groups[l]
        dict_groups = self.get_groups(l=l)
        Bd = dict_groups['Bd']
        p_td_d = dict_groups['p_td_d']

        documents = self.documents

        ## find the group of the query document and all other documents assigned to it
        td = np.argmax(p_td_d[:, doc_index])
        list_doc_index_sel = np.where(p_td_d[td, :] == 1)[0]

        list_doc_query = []
        for doc_index_sel in list_doc_index_sel:
            if doc_index != doc_index_sel:
                list_doc_query += [(doc_index_sel, documents[doc_index_sel])]
        return list_doc_query
    def group_membership(self, l=0):
        '''
        Return the group-membership vectors for
        - document-nodes, p_td_d, array with shape Bd x D
        - word-nodes, p_tw_w, array with shape Bw x V
        They give the probability of a node belonging to each of the groups.
        '''
        # dict_groups = self.groups[l]
        dict_groups = self.get_groups(l=l)
        p_tw_w = dict_groups['p_tw_w']
        p_td_d = dict_groups['p_td_d']
        return p_td_d, p_tw_w
    def print_topics(self, l=0, format='csv', path_save=''):
        '''
        Print topics, topic-distributions, and document clusters for a given level in the hierarchy.
        format: csv (default), tsv, or html
        '''
        V = self.get_V()
        D = self.get_D()

        ## topics
        dict_topics = self.topics(l=l, n=-1)

        list_topics = sorted(list(dict_topics.keys()))
        list_columns = ['Topic %s' % (t + 1) for t in list_topics]

        T = len(list_topics)
        df = pd.DataFrame(columns=list_columns, index=range(V))

        for t in list_topics:
            list_w = [h[0] for h in dict_topics[t]]
            V_t = len(list_w)
            df.iloc[:V_t, t] = list_w
        df = df.dropna(how='all', axis=0)

        if format == 'csv':
            fname_save = 'topsbm_level_%s_topics.csv' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_csv(filename, index=False, na_rep='')
        elif format == 'html':
            fname_save = 'topsbm_level_%s_topics.html' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_html(filename, index=False, na_rep='')
        elif format == 'tsv':
            fname_save = 'topsbm_level_%s_topics.tsv' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_csv(filename, index=False, na_rep='', sep='\t')
        else:
            pass

        ## topic distributions
        list_columns = ['i_doc', 'doc'] + ['Topic %s' % (t + 1) for t in list_topics]
        df = pd.DataFrame(columns=list_columns, index=range(D))
        for i_doc in range(D):
            list_topicdist = self.topicdist(i_doc, l=l)
            df.iloc[i_doc, 0] = i_doc
            df.iloc[i_doc, 1] = self.documents[i_doc]
            df.iloc[i_doc, 2:] = [h[1] for h in list_topicdist]
        df = df.dropna(how='all', axis=1)

        if format == 'csv':
            fname_save = 'topsbm_level_%s_topic-dist.csv' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_csv(filename, index=False, na_rep='')
        elif format == 'html':
            fname_save = 'topsbm_level_%s_topic-dist.html' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_html(filename, index=False, na_rep='')
        else:
            pass

        ## doc-groups
        dict_clusters = self.clusters(l=l, n=-1)

        list_clusters = sorted(list(dict_clusters.keys()))
        list_columns = ['Cluster %s' % (t + 1) for t in list_clusters]

        T = len(list_clusters)
        df = pd.DataFrame(columns=list_columns, index=range(D))

        for t in list_clusters:
            list_d = [h[0] for h in dict_clusters[t]]
            D_t = len(list_d)
            df.iloc[:D_t, t] = list_d
        df = df.dropna(how='all', axis=0)

        if format == 'csv':
            fname_save = 'topsbm_level_%s_clusters.csv' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_csv(filename, index=False, na_rep='')
        elif format == 'html':
            fname_save = 'topsbm_level_%s_clusters.html' % (l)
            filename = os.path.join(path_save, fname_save)
            df.to_html(filename, index=False, na_rep='')
        else:
            pass
    ###########
    ########### HELPER FUNCTIONS
    ###########

    ## get group-topic statistics
    def get_groups(self, l=0):
        '''
        Extract statistics on the group membership of nodes from the inferred state.
        Return a dictionary with:
        - Bd, int, number of doc-groups
        - Bw, int, number of word-groups
        - p_tw_w, array Bw x V; word-group membership:
          prob that word-node w belongs to word-group tw: P(tw | w)
        - p_td_d, array Bd x D; doc-group membership:
          prob that doc-node d belongs to doc-group td: P(td | d)
        - p_w_tw, array V x Bw; topic distribution:
          prob of word w given topic tw: P(w | tw)
        - p_tw_d, array Bw x D; doc-topic mixtures:
          prob of word-group tw in doc d: P(tw | d)
        '''
        V = self.get_V()
        D = self.get_D()
        N = self.get_N()

        g = self.g
        state = self.state
        state_l = state.project_level(l).copy(overlap=True)
        state_l_edges = state_l.get_edge_blocks()  ## labeled half-edges
        counts = 'count' in self.g.ep.keys()

        ## count labeled half-edges, group-memberships
        B = state_l.get_B()
        n_wb = np.zeros((V, B))  ## number of half-edges incident on word-node w and labeled as word-group tw
        n_db = np.zeros((D, B))  ## number of half-edges incident on document-node d and labeled as document-group td
        n_dbw = np.zeros((D, B))  ## number of half-edges incident on document-node d and labeled as word-group tw

        for e in g.edges():
            z1, z2 = state_l_edges[e]
            v1 = e.source()
            v2 = e.target()
            if counts:
                weight = g.ep["count"][e]
            else:
                weight = 1
            n_db[int(v1), z1] += weight
            n_dbw[int(v1), z2] += weight
            n_wb[int(v2) - D, z2] += weight

        p_w = np.sum(n_wb, axis=1) / float(np.sum(n_wb))

        ind_d = np.where(np.sum(n_db, axis=0) > 0)[0]
        Bd = len(ind_d)
        n_db = n_db[:, ind_d]

        ind_w = np.where(np.sum(n_wb, axis=0) > 0)[0]
        Bw = len(ind_w)
        n_wb = n_wb[:, ind_w]

        ind_w2 = np.where(np.sum(n_dbw, axis=0) > 0)[0]
        n_dbw = n_dbw[:, ind_w2]

        ## group-membership distributions
        # group membership of each word-node P(tw | w)
        p_tw_w = (n_wb / np.sum(n_wb, axis=1)[:, np.newaxis]).T

        # group membership of each doc-node P(td | d)
        p_td_d = (n_db / np.sum(n_db, axis=1)[:, np.newaxis]).T

        ## topic-distribution for words P(w | tw)
        p_w_tw = n_wb / np.sum(n_wb, axis=0)[np.newaxis, :]

        ## mixture of word-groups into documents P(tw | d)
        p_tw_d = (n_dbw / np.sum(n_dbw, axis=1)[:, np.newaxis]).T

        result = {}
        result['Bd'] = Bd
        result['Bw'] = Bw
        result['p_tw_w'] = p_tw_w
        result['p_td_d'] = p_td_d
        result['p_w_tw'] = p_w_tw
        result['p_tw_d'] = p_tw_d

        return result
    ### helper functions
    def get_V(self):
        '''
        Return the number of word-nodes == number of word types.
        '''
        return int(np.sum(self.g.vp['kind'].a == 1))  # no. of word types

    def get_D(self):
        '''
        Return the number of doc-nodes == number of documents.
        '''
        return int(np.sum(self.g.vp['kind'].a == 0))  # no. of documents

    def get_N(self):
        '''
        Return the number of edges == number of tokens.
        '''
        return int(self.g.num_edges())  # no. of tokens
    def group_to_group_mixture(self, l=0, norm=True):
        '''
        Return the group-to-group mixture matrix at level l: entry (td, tw) counts the
        (weighted) edges between doc-group td and word-group tw. If norm is True,
        normalize it to the joint probability P(td, tw).
        '''
        V = self.get_V()
        D = self.get_D()
        N = self.get_N()

        g = self.g
        state = self.state
        state_l = state.project_level(l).copy(overlap=True)
        state_l_edges = state_l.get_edge_blocks()  ## labeled half-edges

        ## count labeled half-edges, group-memberships
        B = state_l.get_B()
        n_td_tw = np.zeros((B, B))

        counts = 'count' in self.g.ep.keys()

        for e in g.edges():
            z1, z2 = state_l_edges[e]
            if counts:
                n_td_tw[z1, z2] += g.ep["count"][e]
            else:
                n_td_tw[z1, z2] += 1

        ind_d = np.where(np.sum(n_td_tw, axis=1) > 0)[0]
        Bd = len(ind_d)
        ind_w = np.where(np.sum(n_td_tw, axis=0) > 0)[0]
        Bw = len(ind_w)

        n_td_tw = n_td_tw[:Bd, Bd:]
        if norm:
            return n_td_tw / np.sum(n_td_tw)
        else:
            return n_td_tw
    def pmi_td_tw(self, l=0):
        r'''
        Pointwise mutual information between topic-groups and doc-groups, S(td, tw).
        This is an array of shape Bd x Bw.

        It corresponds to
            S(td, tw) = log[ P(tw | td) / \tilde{P}(tw | td) ],
        the log-ratio between
            P(tw | td), the probability of topic tw in doc-group td, and
            \tilde{P}(tw | td) = P(tw), the expected probability of topic tw
            in doc-group td under a random null model.
        '''
        p_td_tw = self.group_to_group_mixture(l=l)
        p_tw_td = p_td_tw.T
        p_td = np.sum(p_tw_td, axis=0)
        p_tw = np.sum(p_tw_td, axis=1)
        pmi_td_tw = np.log(p_tw_td / (p_td * p_tw[:, np.newaxis])).T
        return pmi_td_tw
    def print_summary(self, tofile=True):
        '''
        Print the hierarchy summary (to summary.txt if tofile is True).
        '''
        if tofile:
            orig_stdout = sys.stdout
            f = open('summary.txt', 'w')
            sys.stdout = f
            self.state.print_summary()
            sys.stdout = orig_stdout
            f.close()
        else:
            self.state.print_summary()
    def plot_topic_dist(self, l):
        '''
        Plot P(w | tw) and P(tw | d) for level l as heatmaps and save them as png files.
        '''
        groups = self.get_groups(l=l)  ## compute the group statistics for this level

        p_w_tw = groups['p_w_tw']
        fig = plt.figure(figsize=(12, 10))
        plt.imshow(p_w_tw, origin='lower', aspect='auto', interpolation='none')
        plt.title(r'Word group membership $P(w | tw)$')
        plt.xlabel('Topic, tw')
        plt.ylabel('Word w (index)')
        plt.colorbar()
        fig.savefig("p_w_tw_%d.png" % l)

        p_tw_d = groups['p_tw_d']
        fig = plt.figure(figsize=(12, 10))
        plt.imshow(p_tw_d, origin='lower', aspect='auto', interpolation='none')
        plt.title(r'Topic mixture $P(tw | d)$')
        plt.xlabel('Document (index)')
        plt.ylabel('Topic, tw')
        plt.colorbar()
        fig.savefig("p_tw_d_%d.png" % l)
    def save_data(self):
        for i in range(len(self.state.get_levels()) - 2)[::-1]:
            print("Saving level %d" % i)
            self.print_topics(l=i)
            self.print_topics(l=i, format='tsv')
            self.plot_topic_dist(i)

            e = self.state.get_levels()[i].get_matrix()
            plt.matshow(e.todense())
            plt.savefig("mat_%d.png" % i)
        self.print_summary()
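
For reference, a minimal end-to-end sketch of how this class can be driven (the toy corpus and variable names below are hypothetical, not part of the commit; assumes graph-tool, numpy, pandas, and matplotlib are installed):

from sbmtm import sbmtm

texts = [['the', 'cat', 'sat'], ['the', 'dog', 'ran'], ['cat', 'and', 'dog']]
model = sbmtm()
model.make_graph(texts, documents=['d0', 'd1', 'd2'])
model.fit(n_init=2)             ## keep the best of 2 runs by description length
print(model.topics(l=0, n=5))   ## top words per word-group at level 0
print(model.topicdist(0, l=0))  ## topic mixture of the first document
model.print_topics(l=0)         ## write topsbm_level_0_*.csv files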