Introduction

There have been many interesting applications of graph structure to text data for natural language processing tasks. For a brief overview, check out my other article here. In this article, I want to go into more depth on one of the papers mentioned there: Graph Convolutional Networks for Text Classification by Yao et al.

Traditionally, models for text classification have focused on the effectiveness of word embeddings, and on aggregating word embeddings into document embeddings. These word embeddings could be unsupervised pre-trained embeddings (think word2vec or GloVe) which are then fed into a classifier. More recently, deep learning models such as CNNs and RNNs have emerged as useful encoders of text. In both cases, text representations are learned from word embeddings. The approach by Yao et al. proposes to learn the word and document embeddings simultaneously for text classification.

At a high level, the model by Yao et al., called \textsc{TextGCN}, takes a given corpus of documents and words and constructs a graph in which documents and words are nodes (details about edges are discussed later). With this constructed graph, \textsc{TextGCN} uses a Graph Convolutional Network to learn better node representations (representations for the words and documents). These updated representations can then be fed into a classifier.

Graph Convolutional Network Prerequisites

Here I will give a really quick overview of Graph Convolutional Networks (GCNs). For a more thorough introduction, please check out this introduction on Medium.

First, let us define the input to a GCN: a graph. A graph \mathcal{G}=(V,E) is comprised of a node set V and an edge set E. The edges connecting nodes can be represented by an adjacency matrix \textbf{A} \in \mathbb{R}^{|V| \times |V|}. If \textbf{A}_{ij} is positive, there is an edge between node i and node j with edge weight \textbf{A}_{ij}. A GCN also adds self-loops to the nodes, so the adjacency matrix becomes

(1)   \begin{equation*} \hat{\textbf{A}} = \textbf{A} + \textbf{I} \end{equation*}

Additionally, each node can be represented by a vector, considered the node’s features. Thus we also have a matrix \textbf{X} \in \mathbb{R}^{|V| \times D_{in}}, in which \textbf{X}_i is the feature vector for the i^{th} node and D_{in} is the dimension of the features. We also have a weight matrix \textbf{W}_t \in \mathbb{R}^{D_{in} \times D_{out}}, where D_{out} is the output dimension. We finally introduce a diagonal degree matrix \textbf{D} \in \mathbb{R}^{|V| \times |V|}, where each diagonal entry is the degree of the corresponding node (computed from \hat{\textbf{A}}, i.e. including the self-loop). Now we can describe one layer of a GCN by the equation below, where f(\cdot) is any activation function.

(2)   \begin{equation*} \textbf{X}_{t+1} = f(\textbf{D}^{-1/2} \hat{\textbf{A}} \textbf{D}^{-1/2} \textbf{X}_t \textbf{W}_{t}), \qquad \textbf{X}_{t+1} \in \mathbb{R}^{|V| \times D_{out}} \end{equation*}

Let’s explain this equation step by step, reading from the inside out. First, \tilde{\textbf{X}}_t = \textbf{X}_t \textbf{W}_{t} \in \mathbb{R}^{|V| \times D_{out}} takes the input node representations and multiplies them by the layer weights. Also note that \tilde{\textbf{A}} = \textbf{D}^{-1/2} \hat{\textbf{A}} \textbf{D}^{-1/2} \in \mathbb{R}^{|V| \times |V|} is essentially the adjacency matrix, but normalized by node degree. Thus we have the product \tilde{\textbf{A}} \tilde{\textbf{X}}_t, which works out to a weighted sum of the node representations of each node’s neighbors (each nonzero value \tilde{\textbf{A}}_{ij} in the i^{th} row of \tilde{\textbf{A}} corresponds to an edge between node j and node i with normalized weight \tilde{\textbf{A}}_{ij}).

Thus, a GCN layer basically sums up the node representations of a node’s neighbors and of the node itself (remember the self-loop). Stacking K layers of GCN allows nodes to receive information from other nodes K hops away. GCNs fall under a class of Graph Neural Networks called message passing networks, in which messages (in this case, the edge weight multiplied by the node representation) are passed between neighbors. I like to think of these message passing networks as helping learn node representations that take into account a node’s nearby neighbors, as described by the graph structure. Thus, the way the graph is constructed, namely which edges are formed between which nodes, is very important. In the next section, I will discuss how the text graph in Graph Convolutional Networks for Text Classification is constructed.
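To make the layer equation concrete, here is a minimal NumPy sketch I wrote of a single GCN layer. It assumes a dense adjacency matrix and a ReLU activation; real implementations use sparse matrices and learn the weights by backpropagation.

```python
import numpy as np

def gcn_layer(A, X, W):
    """One GCN layer: f(D^{-1/2} (A + I) D^{-1/2} X W), with ReLU as f.

    A: (|V|, |V|) adjacency matrix without self-loops
    X: (|V|, D_in) node feature matrix
    W: (D_in, D_out) weight matrix
    """
    A_hat = A + np.eye(A.shape[0])              # add self-loops (Eq. 1)
    deg = A_hat.sum(axis=1)                     # node degrees of A_hat
    D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))    # D^{-1/2}
    A_tilde = D_inv_sqrt @ A_hat @ D_inv_sqrt   # symmetrically normalized adjacency
    return np.maximum(0.0, A_tilde @ X @ W)     # weighted neighbor aggregation + ReLU
```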

Graph Construction

The details of constructing the “text” graph are as follows. First, the total number of nodes is the number of documents d plus the number of unique words w. The node feature matrix is the identity matrix, i.e. \textbf{X} = \textbf{I}; each node representation is thus a one-hot vector. The adjacency matrix (which encodes the edges among document and word nodes) is defined as follows.
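In words: word-word edges are weighted by positive pointwise mutual information (PMI), document-word edges are weighted by TF-IDF, and every node gets a self-loop of weight 1:

\begin{equation*} \textbf{A}_{ij} = \begin{cases} \mathrm{PMI}(i, j) & i, j \text{ are words and } \mathrm{PMI}(i, j) > 0 \\ \mathrm{TF\text{-}IDF}_{ij} & i \text{ is a document}, j \text{ is a word} \\ 1 & i = j \\ 0 & \text{otherwise} \end{cases} \end{equation*}

\begin{equation*} \mathrm{PMI}(i, j) = \log \frac{p(i, j)}{p(i)\, p(j)}, \qquad p(i, j) = \frac{\#W(i, j)}{\#W}, \qquad p(i) = \frac{\#W(i)}{\#W} \end{equation*}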

These equations are taken from the paper. #W(i, j) is the number of sliding windows that contain both word i and word j, #W(i) is the number of sliding windows containing word i, and #W is the total number of sliding windows.

PMI(i, j) is the pointwise mutual information between two word nodes, which measures how strongly the two words co-occur. The window size for computing co-occurrences is a hyperparameter of the model; in this paper, the authors set it to 20. Intuitively, the graph construction attempts to place similar words and documents close to each other in the graph.
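As a rough illustration, here is a sketch of how one might compute these edge weights in Python. The function name build_text_graph, the whitespace tokenization, and the exact TF-IDF variant are my own simplifications, not the authors' implementation.

```python
import math
from collections import Counter

def build_text_graph(docs, window_size=20):
    """TextGCN-style edge weights: TF-IDF for document-word edges and positive PMI
    for word-word edges. Documents get node ids 0..len(docs)-1, words come after.
    Self-loops are added later when building A_hat (Eq. 1)."""
    tokenized = [doc.split() for doc in docs]                 # naive whitespace tokenizer
    vocab = sorted({w for tokens in tokenized for w in tokens})
    word_id = {w: len(docs) + k for k, w in enumerate(vocab)}
    edges = {}

    # --- document-word edges: TF-IDF ---
    df = Counter(w for tokens in tokenized for w in set(tokens))   # document frequency
    for i, tokens in enumerate(tokenized):
        for w, tf in Counter(tokens).items():
            weight = tf * math.log(len(docs) / df[w])
            edges[(i, word_id[w])] = edges[(word_id[w], i)] = weight

    # --- word-word edges: positive PMI over sliding windows ---
    windows = []
    for tokens in tokenized:
        if len(tokens) <= window_size:
            windows.append(tokens)
        else:
            windows += [tokens[k:k + window_size]
                        for k in range(len(tokens) - window_size + 1)]
    w_count, pair_count = Counter(), Counter()
    for win in windows:
        uniq = sorted(set(win))
        w_count.update(uniq)                                   # #W(i)
        pair_count.update((a, b) for n, a in enumerate(uniq) for b in uniq[n + 1:])  # #W(i, j)
    n_win = len(windows)                                       # #W
    for (a, b), n_ab in pair_count.items():
        pmi = math.log((n_ab / n_win) / ((w_count[a] / n_win) * (w_count[b] / n_win)))
        if pmi > 0:                                            # keep only positive-PMI edges
            edges[(word_id[a], word_id[b])] = edges[(word_id[b], word_id[a])] = pmi
    return edges, word_id
```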

Model

After this construction, the authors simply run a two-layer GCN followed by a softmax function to predict the label. Formally, the equation is

(3)   \begin{equation*} \textbf{Z} = \mathrm{softmax}(\tilde{\textbf{A}}\,\mathrm{ReLU}(\tilde{\textbf{A}} \textbf{X} \textbf{W}_{0})\,\textbf{W}_{1}) \end{equation*}

For the training loss, cross-entropy over the labeled document nodes is used.
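Putting it together, here is a PyTorch sketch of the two-layer model in Eq. (3). The class name TwoLayerGCN, the dense matrix multiplications, and the hidden dimension in the usage comment are illustrative choices on my part; the authors' implementation uses sparse matrices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoLayerGCN(nn.Module):
    """Two-layer GCN of Eq. (3): Z = softmax(A_tilde ReLU(A_tilde X W0) W1)."""
    def __init__(self, n_nodes, hidden_dim, n_classes):
        super().__init__()
        self.W0 = nn.Linear(n_nodes, hidden_dim, bias=False)   # X = I, so D_in = |V|
        self.W1 = nn.Linear(hidden_dim, n_classes, bias=False)

    def forward(self, A_tilde, X):
        h = F.relu(A_tilde @ self.W0(X))     # first GCN layer
        return A_tilde @ self.W1(h)          # second GCN layer (logits; softmax lives in the loss)

# Training sketch: cross-entropy on the labeled (training) document nodes only.
# model = TwoLayerGCN(n_nodes=A_tilde.shape[0], hidden_dim=200, n_classes=n_classes)
# logits = model(A_tilde, torch.eye(A_tilde.shape[0]))
# loss = F.cross_entropy(logits[train_doc_idx], labels[train_doc_idx])
```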

Experiments

The authors compare their model against CNN and LSTM variants as well as other baseline word and paragraph embedding models. The comparisons are performed on the five datasets listed below; the full results table can be found in the paper.

  1. 20NG: over 18,000 documents evenly distributed across 20 categories
  2. Ohsumed: over 7,000 cardiovascular disease abstracts with 23 disease categories
  3. R52: a subset of Reuters-21578 (documents that appeared on the Reuters newswire in 1987), with roughly 10,000 documents in 52 categories
  4. R8: also a subset of Reuters-21578, but with about 7,500 documents and 8 categories
  5. MR: a movie review dataset containing about 10,000 reviews and two categories, positive and negative sentiment

From the results, we can see that \textsc{TextGCN} performs best or near best on each dataset when compared to CNN, LSTM, and the other baselines. This performance comes purely from the edges and edge weights defined in the previous section, as there are no informative node features (remember that each node is represented by a one-hot embedding).

Learned Representations

To gain some insight into the learned representations, the authors show a t-SNE visualization of the document embeddings obtained through \textsc{TextGCN}. We can see that even after only one GCN layer, the document embeddings already separate the classes decently well.

t-SNE visualization of document embeddings for the 20NG dataset

More concretely, we can also look at the top 10 words for each class under the embeddings from \textsc{TextGCN}. We can see that the model surfaces related words for each category.

Closing Thoughts

I think this paper shows an interesting application of a simple GCN to text classification and does show promising results. However, the model has the limitation that it is transductive (a limitation of GCNs in general). During training, the model sees every word and document in the dataset, including the test set. Although no predictions are made on the test set during training, this setting cannot be applied to, say, a completely new document. This points to possible future work on how to incorporate new documents into an already constructed graph. Overall, I think the paper shows the power of Graph Neural Networks and their applicability to any domain in which we can define and build some kind of useful graph structure.
