{"id":1437,"date":"2022-01-04T19:41:02","date_gmt":"2022-01-04T19:41:02","guid":{"rendered":"https:\/\/salarydistribution.com\/machine-learning\/2022\/01\/04\/train-graph-neural-nets-for-millions-of-proteins-on-amazon-sagemaker-and-amazon-documentdb-with-mongodb-compatibility\/"},"modified":"2022-01-04T19:41:02","modified_gmt":"2022-01-04T19:41:02","slug":"train-graph-neural-nets-for-millions-of-proteins-on-amazon-sagemaker-and-amazon-documentdb-with-mongodb-compatibility","status":"publish","type":"post","link":"https:\/\/salarydistribution.com\/machine-learning\/2022\/01\/04\/train-graph-neural-nets-for-millions-of-proteins-on-amazon-sagemaker-and-amazon-documentdb-with-mongodb-compatibility\/","title":{"rendered":"Train graph neural nets for millions of proteins on Amazon SageMaker and Amazon DocumentDB (with MongoDB compatibility)"},"content":{"rendered":"<div id=\"\">\n<p>There are over 180,000 unique proteins with 3D structures determined, with <a href=\"https:\/\/www.rcsb.org\/stats\/all-released-structures\" target=\"_blank\" rel=\"noopener noreferrer\">tens of thousands new structures resolved every year<\/a>. This is only a small fraction of the <a href=\"https:\/\/www.ebi.ac.uk\/uniprot\/TrEMBLstats\" target=\"_blank\" rel=\"noopener noreferrer\">200 million known proteins with distinctive sequences<\/a>. Recent deep learning algorithms such as <a href=\"https:\/\/www.nature.com\/articles\/s41586-021-03819-2\" target=\"_blank\" rel=\"noopener noreferrer\">AlphaFold<\/a> can accurately predict 3D structures of proteins using their sequences, which help scale the protein 3D structure data to the millions. Graph neural network (GNN) has emerged as an effective deep learning approach to extract information from protein structures, which can be represented by graphs of amino acid residues. Individual protein graphs usually contain a few hundred nodes, which is manageable in size. 
Tens of thousands of protein graphs can be easily stored in serialized data structures such as <a href=\"https:\/\/www.tensorflow.org\/tutorials\/load_data\/tfrecord\" target=\"_blank\" rel=\"noopener noreferrer\">TFRecord<\/a> for training GNNs. However, training GNNs on millions of protein structures is challenging. Data serialization isn\u2019t scalable to millions of protein structures because it requires loading the entire terabyte-scale dataset into memory.<\/p>\n<p>In this post, we introduce a scalable deep learning solution that allows you to train GNNs on millions of proteins stored in <a href=\"https:\/\/aws.amazon.com\/documentdb\/\" target=\"_blank\" rel=\"noopener noreferrer\">Amazon DocumentDB (with MongoDB compatibility)<\/a> using <a href=\"https:\/\/aws.amazon.com\/sagemaker\/\" target=\"_blank\" rel=\"noopener noreferrer\">Amazon SageMaker<\/a>.<\/p>\n<p>For illustrative purposes, we use publicly available experimentally determined protein structures from the <a href=\"https:\/\/www.rcsb.org\/\" target=\"_blank\" rel=\"noopener noreferrer\">Protein Data Bank<\/a> and computationally predicted protein structures from the <a href=\"https:\/\/alphafold.ebi.ac.uk\/\" target=\"_blank\" rel=\"noopener noreferrer\">AlphaFold Protein Structure Database<\/a>. The machine learning (ML) problem is to develop a discriminator GNN model that distinguishes between experimental and predicted structures based on protein graphs constructed from their 3D structures.<\/p>\n<h2>Overview of solution<\/h2>\n<p>We first parse the protein structures into JSON records with multiple types of data structures, such as n-dimensional arrays and nested objects, to store the proteins\u2019 atomic coordinates, properties, and identifiers. Storing a JSON record for a protein\u2019s structure takes 45 KB on average; we project that storing 100 million proteins would take around 4.2 TB. 
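For reference, each parsed protein becomes a nested JSON document along these lines. This is a minimal hypothetical sketch: the field names `coords`, `seq`, and `is_AF` are the ones queried later in this post, the identifier follows the AlphaFold DB naming used later, and all values are illustrative rather than real data:

```python
import json

# Hypothetical protein document (illustrative values, not real data).
# coords stores per-residue backbone atoms; index 1 holds the C-alpha atom,
# matching the coords[:, 1] lookup in convert_to_graph later in this post.
protein_doc = {
    "id": "AF-Q57887",   # protein identifier
    "is_AF": True,       # whether the structure is AlphaFold-predicted
    "seq": "MKV",        # amino acid sequence in one-letter codes
    "coords": [          # one [N, C-alpha, C, O] x [x, y, z] block per residue
        [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.2, 1.2, 0.0], [3.4, 1.3, 0.1]]
        for _ in range(3)
    ],
}

# one block of atomic coordinates per sequence position
assert len(protein_doc["coords"]) == len(protein_doc["seq"])
size_kb = len(json.dumps(protein_doc).encode("utf-8")) / 1024
```

A full-size protein with a few hundred residues and all fields serializes to tens of kilobytes, which is where the roughly 45 KB per-document average comes from.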
Amazon DocumentDB storage <a href=\"https:\/\/docs.aws.amazon.com\/documentdb\/latest\/developerguide\/db-cluster-manage-performance.html#db-cluster-manage-scaling-storage\" target=\"_blank\" rel=\"noopener noreferrer\">automatically scales with the data<\/a> in your cluster volume in 10 GB increments, up to 64 TB. Its native support for JSON data structures, combined with this storage scalability, makes Amazon DocumentDB a natural choice.<\/p>\n<p>We next build a GNN model to predict protein properties using graphs of amino acid residues constructed from their structures. The GNN model is trained using SageMaker and configured to efficiently retrieve batches of protein structures from the database.<\/p>\n<p>Finally, we analyze the trained GNN model to gain some insights into the predictions.<\/p>\n<p>We walk through the following steps for this tutorial:<\/p>\n<ol>\n<li>Create resources using an <a href=\"http:\/\/aws.amazon.com\/cloudformation\" target=\"_blank\" rel=\"noopener noreferrer\">AWS CloudFormation<\/a> template.<\/li>\n<li>Prepare protein structures and properties and ingest the data into Amazon DocumentDB.<\/li>\n<li>Train a GNN on the protein structures using SageMaker.<\/li>\n<li>Load and evaluate the trained GNN model.<\/li>\n<\/ol>\n<p>The code and notebooks used in this post are available in the <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\" target=\"_blank\" rel=\"noopener noreferrer\">GitHub repo<\/a>.<\/p>\n<h2>Prerequisites<\/h2>\n<p>For this walkthrough, you should have the following prerequisites:<\/p>\n<p>Running this tutorial for an hour should cost no more than $2.00.<\/p>\n<h2>Create resources<\/h2>\n<p>We provide a <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\/blob\/main\/cloudformation.yaml\" target=\"_blank\" rel=\"noopener noreferrer\">CloudFormation template<\/a> to create the required AWS resources for this post, with an architecture similar to that in 
the post <a href=\"https:\/\/aws.amazon.com\/blogs\/machine-learning\/analyzing-data-stored-in-amazon-documentdb-with-mongodb-compatibility-using-amazon-sagemaker\/\" target=\"_blank\" rel=\"noopener noreferrer\">Analyzing data stored in Amazon DocumentDB (with MongoDB compatibility) using Amazon SageMaker<\/a>. For instructions on creating a CloudFormation stack, see the video <a href=\"https:\/\/www.youtube.com\/watch?v=1h-GPXQrLZw&amp;feature=youtu.be&amp;t=153&amp;app=desktop\" target=\"_blank\" rel=\"noopener noreferrer\">Simplify your Infrastructure Management using AWS CloudFormation<\/a>.<\/p>\n<p>The CloudFormation stack provisions the following:<\/p>\n<ul>\n<li>A VPC with three private subnets for Amazon DocumentDB and two public subnets intended for the SageMaker notebook instance and ML training containers, respectively.<\/li>\n<li>An Amazon DocumentDB cluster with three nodes, one in each private subnet.<\/li>\n<li>A Secrets Manager secret to store login credentials for Amazon DocumentDB. This allows us to avoid storing plaintext credentials in our SageMaker instance.<\/li>\n<li>A SageMaker notebook instance to prepare data, orchestrate training jobs, and run interactive analyses.<\/li>\n<\/ul>\n<p>When creating the CloudFormation stack, you need to specify the following:<\/p>\n<ul>\n<li>Name for your CloudFormation stack<\/li>\n<li>Amazon DocumentDB user name and password (to be stored in Secrets Manager)<\/li>\n<li>Amazon DocumentDB instance type (default db.r5.large)<\/li>\n<li>SageMaker instance type (default ml.t3.xlarge)<\/li>\n<\/ul>\n<p>It should take about 15 minutes to create the CloudFormation stack. 
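If you prefer the AWS CLI over the console, the stack can be created from the same template. This is a sketch under assumptions: the parameter keys shown are placeholders that must match those defined in cloudformation.yaml, while gnn-proteins is the stack name used later in the notebooks:

```shell
# Create the stack from the template in the GitHub repo (parameter keys are
# placeholders; check cloudformation.yaml for the actual names).
aws cloudformation create-stack \
    --stack-name gnn-proteins \
    --template-body file://cloudformation.yaml \
    --parameters \
        ParameterKey=DocDBUsername,ParameterValue=<your-username> \
        ParameterKey=DocDBPassword,ParameterValue=<your-password> \
    --capabilities CAPABILITY_IAM

# Block until provisioning finishes (about 15 minutes)
aws cloudformation wait stack-create-complete --stack-name gnn-proteins
```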
The following diagram shows the resource architecture.<\/p>\n<p><a href=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/ML-5542-image001.png\"><img decoding=\"async\" loading=\"lazy\" class=\"alignnone size-full wp-image-31646\" src=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/ML-5542-image001.png\" alt=\"\" width=\"1856\" height=\"1196\"><\/a><\/p>\n<h2>Prepare protein structures and properties and ingest the data into Amazon DocumentDB<\/h2>\n<p>All the subsequent code in this section is in the Jupyter notebook <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\/blob\/main\/Prepare_data.ipynb\" target=\"_blank\" rel=\"noopener noreferrer\">Prepare_data.ipynb<\/a> in the SageMaker instance created in your CloudFormation stack.<\/p>\n<p>This notebook handles the procedures required for preparing and ingesting protein structure data into Amazon DocumentDB.<\/p>\n<ol>\n<li>We first download predicted protein structures from <a href=\"https:\/\/alphafold.ebi.ac.uk\/download\" target=\"_blank\" rel=\"noopener noreferrer\">AlphaFold DB<\/a> in PDB format and the matching experimental structures from the <a href=\"https:\/\/www.rcsb.org\/\" target=\"_blank\" rel=\"noopener noreferrer\">Protein Data Bank<\/a>.<\/li>\n<\/ol>\n<p>For demonstration purposes, we only use proteins from the thermophilic archaean <a href=\"https:\/\/en.wikipedia.org\/wiki\/Methanocaldococcus_jannaschii\" target=\"_blank\" rel=\"noopener noreferrer\">Methanocaldococcus jannaschii<\/a>, which has the smallest proteome of 1,773 proteins for us to work with. 
You are welcome to try using proteins from other species.<\/p>\n<ol start=\"2\">\n<li>We connect to an Amazon DocumentDB cluster by retrieving the credentials stored in Secrets Manager:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">import json\n\nimport boto3\nfrom pymongo import MongoClient\n\ndef get_secret(stack_name):\n    # Create a Secrets Manager client\n    session = boto3.session.Session()\n    client = session.client(\n        service_name=\"secretsmanager\",\n        region_name=session.region_name\n    )\n\n    secret_name = f\"{stack_name}-DocDBSecret\"\n    get_secret_value_response = client.get_secret_value(SecretId=secret_name)\n    secret = get_secret_value_response[\"SecretString\"]\n\n    return json.loads(secret)\n\nsecrets = get_secret(\"gnn-proteins\")\n\n# connect to DocumentDB\nuri = (\n    \"mongodb:\/\/{}:{}@{}:{}\/?tls=true&amp;tlsCAFile=rds-combined-ca-bundle.pem\"\n    \"&amp;replicaSet=rs0&amp;readPreference=secondaryPreferred&amp;retryWrites=false\"\n).format(secrets[\"username\"], secrets[\"password\"], secrets[\"host\"], secrets[\"port\"])\n\nclient = MongoClient(uri)\n\ndb = client[\"proteins\"] # create a database\ncollection = db[\"proteins\"] # create a collection<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"3\">\n<li>After we set up the connection to Amazon DocumentDB, we parse the PDB files into JSON records to ingest into the database.<\/li>\n<\/ol>\n<p>We provide the utility functions required for parsing PDB files in <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\/blob\/main\/pdb_parse.py\" target=\"_blank\" rel=\"noopener noreferrer\">pdb_parse.py<\/a>. The <code>parse_pdb_file_to_json_record<\/code> function does the heavy lifting of extracting atomic coordinates from one or multiple peptide chains in a PDB file and returns one JSON document or a list of them, which can be ingested directly into the Amazon DocumentDB collection. 
See the following code:<\/p>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">recs = parse_pdb_file_to_json_record(pdb_parser, pdb_file, pdb_id)\ncollection.insert_many(recs)<\/code><\/pre>\n<\/p><\/div>\n<p>After we ingest the parsed protein data into Amazon DocumentDB, we can update the contents of the protein documents. For instance, it makes our model training logistics easier if we add a field indicating whether a protein structure should be used in the training, validation, or test set.<\/p>\n<ol start=\"4\">\n<li>We first retrieve all the documents with the field <code>is_AF<\/code> to stratify the documents using an aggregation pipeline:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">match = {\"is_AF\": {\"$exists\": True}}\nproject = {\"y\": \"$is_AF\"}\n\npipeline = [\n    {\"$match\": match},\n    {\"$project\": project},\n]\n# aggregation pipeline\ncur = collection.aggregate(pipeline)\n# retrieve documents from the DB cursor\ndocs = [doc for doc in cur]\n# convert to a data frame:\ndf = pd.DataFrame(docs)\n# stratified split: full -&gt; train\/test\ndf_train, df_test = train_test_split(\n    df, \n    test_size=0.2,\n    stratify=df[\"y\"], \n    random_state=42\n)\n# stratified split: train -&gt; train\/valid\ndf_train, df_valid = train_test_split(\n    df_train, \n    test_size=0.2,\n    stratify=df_train[\"y\"], \n    random_state=42\n)<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"5\">\n<li>Next, we use the <code>update_many<\/code> function to store the split information back in Amazon DocumentDB:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">for split, df_split in zip(\n    [\"train\", \"valid\", \"test\"],\n    [df_train, df_valid, df_test]\n):\n    result = collection.update_many(\n        {\"_id\": {\"$in\": df_split[\"_id\"].tolist()}}, \n        {\"$set\": {\"split\": split}}\n    )\n    print(\"Number of documents modified:\", 
result.modified_count)<\/code><\/pre>\n<\/p><\/div>\n<h2>Train a GNN on the protein structures using SageMaker<\/h2>\n<p>All the subsequent code in this section is in the <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\/blob\/main\/Train_and_eval.ipynb\" target=\"_blank\" rel=\"noopener noreferrer\">Train_and_eval.ipynb<\/a> notebook in the SageMaker instance created in your CloudFormation stack.<\/p>\n<p>This notebook trains a GNN model on the protein structure datasets stored in Amazon DocumentDB.<\/p>\n<p>We first need to implement a PyTorch dataset class capable of retrieving mini-batches of protein documents from Amazon DocumentDB. It\u2019s more efficient to retrieve batches of documents by the built-in primary key (<code>_id<\/code>).<\/p>\n<ol>\n<li>We use an iterable-style dataset by extending <a href=\"https:\/\/pytorch.org\/docs\/stable\/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset\" target=\"_blank\" rel=\"noopener noreferrer\">IterableDataset<\/a>, which pre-fetches the <code>_id<\/code> and labels of the documents at initialization:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">class ProteinDataset(data.IterableDataset):\n    \"\"\"\n    An iterable-style dataset for proteins in DocumentDB\n    Args:\n        pipeline: an aggregation pipeline to retrieve data from DocumentDB\n        db_uri: URI of the DocumentDB\n        db_name: name of the database\n        collection_name: name of the collection\n        k: k used for kNN when creating a graph from atomic coordinates\n    \"\"\"\n\n    def __init__(\n        self, pipeline, db_uri=\"\", db_name=\"\", collection_name=\"\", k=3\n    ):\n\n        self.db_uri = db_uri\n        self.db_name = db_name\n        self.collection_name = collection_name\n        self.k = k\n\n        client = MongoClient(self.db_uri, connect=False)\n        collection = 
client[self.db_name][self.collection_name]\n        # pre-fetch the metadata as docs from DocumentDB\n        self.docs = [doc for doc in collection.aggregate(pipeline)]\n        # mapping document '_id' to label\n        self.labels = {doc[\"_id\"]: doc[\"y\"] for doc in self.docs}<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"2\">\n<li>The <code>ProteinDataset<\/code> performs a database read operation in the <code>__iter__<\/code> method. It tries to evenly split the workload if there are multiple workers:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-bash\">def __iter__(self):\n        worker_info = torch.utils.data.get_worker_info()\n        if worker_info is None:\n            # single-process data loading, return the full iterator\n            protein_ids = [doc[\"_id\"] for doc in self.docs]\n\n        else:  # in a worker process\n            # split workload\n            start = 0\n            end = len(self.docs)\n            per_worker = int(\n                math.ceil((end - start) \/ float(worker_info.num_workers))\n            )\n            worker_id = worker_info.id\n            iter_start = start + worker_id * per_worker\n            iter_end = min(iter_start + per_worker, end)\n\n            protein_ids = [\n                doc[\"_id\"] for doc in self.docs[iter_start:iter_end]\n            ]\n\n        # retrieve a list of proteins by _id from DocDB\n        with MongoClient(self.db_uri) as client:\n            collection = client[self.db_name][self.collection_name]\n            cur = collection.find(\n                {\"_id\": {\"$in\": protein_ids}},\n                projection={\"coords\": True, \"seq\": True},\n            )\n            return (\n                (\n                    convert_to_graph(protein, k=self.k),\n                    self.labels[protein[\"_id\"]],\n                )\n                for protein in cur\n            )<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"3\">\n<li>The preceding 
<code>__iter__<\/code> method also converts the atomic coordinates of proteins into <a href=\"https:\/\/docs.dgl.ai\/en\/0.6.x\/api\/python\/dgl.DGLGraph.html#dgl.DGLGraph\" target=\"_blank\" rel=\"noopener noreferrer\">DGLGraph<\/a> objects after they\u2019re loaded from Amazon DocumentDB via the <code>convert_to_graph<\/code> function. This function constructs a k-nearest neighbor (kNN) graph for the amino acid residues using the 3D coordinates of the C-alpha atoms and adds one-hot encoded node features to represent residue identities:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-bash\">def convert_to_graph(protein, k=3):\n    \"\"\"\n    Convert a protein (dict) to a dgl graph using kNN.\n    \"\"\"\n    coords = torch.tensor(protein[\"coords\"])\n    X_ca = coords[:, 1]\n    # construct knn graph from C-alpha coordinates\n    g = dgl.knn_graph(X_ca, k=k)\n    seq = protein[\"seq\"]\n    node_features = torch.tensor([d1_to_index[residue] for residue in seq])\n    node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(\n        dtype=torch.float\n    )\n\n    # add node features\n    g.ndata[\"h\"] = node_features\n    return g<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"4\">\n<li>With the <code>ProteinDataset<\/code> implemented, we can initialize instances for train, validation, and test datasets and wrap the training instance with <code>BufferedShuffleDataset<\/code> to enable shuffling.<\/li>\n<li>We further wrap them with <code>torch.utils.data.DataLoader<\/code> to work with other components of the <a href=\"https:\/\/sagemaker.readthedocs.io\/en\/stable\/frameworks\/pytorch\/sagemaker.pytorch.html\" target=\"_blank\" rel=\"noopener noreferrer\">SageMaker PyTorch Estimator<\/a> training script.<\/li>\n<li>Next, we implement a simple two-layered graph convolution network (GCN) with a global attention pooling layer for ease of interpretation:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code 
class=\"lang-bash\">class GCN(nn.Module):\n    \"\"\"A two layer Graph Conv net with Global Attention Pooling over the\n    nodes.\n    Args:\n        in_feats: int, dim of input node features\n        h_feats: int, dim of hidden layers\n        num_classes: int, number of output units\n    \"\"\"\n\n    def __init__(self, in_feats, h_feats, num_classes):\n        super(GCN, self).__init__()\n        self.conv1 = GraphConv(in_feats, h_feats)\n        self.conv2 = GraphConv(h_feats, h_feats)\n        # the gate layer that maps node feature to outputs\n        self.gate_nn = nn.Linear(h_feats, num_classes)\n        self.gap = GlobalAttentionPooling(self.gate_nn)\n        # the output layer making predictions\n        self.output = nn.Linear(h_feats, num_classes)\n\n    def _conv_forward(self, g):\n        \"\"\"forward pass through the GraphConv layers\"\"\"\n        in_feat = g.ndata[\"h\"]\n        h = self.conv1(g, in_feat)\n        h = F.relu(h)\n        h = self.conv2(g, h)\n        h = F.relu(h)\n        return h\n\n    def forward(self, g):\n        h = self._conv_forward(g)\n        h = self.gap(g, h)\n        return self.output(h)\n\n    def attention_scores(self, g):\n        \"\"\"Calculate attention scores\"\"\"\n        h = self._conv_forward(g)\n        with g.local_scope():\n            gate = self.gap.gate_nn(h)\n            g.ndata[\"gate\"] = gate\n            gate = dgl.softmax_nodes(g, \"gate\")\n            g.ndata.pop(\"gate\")\n            return gate<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"7\">\n<li>Afterwards, we can train this GCN on the <code>ProteinDataset<\/code> instance for a binary classification task of predicting whether a protein structure is predicted by AlphaFold or not. We use binary cross entropy as the objective function and Adam optimizer for stochastic gradient optimization. 
The full training script can be found in <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\/blob\/main\/src\/main.py\" target=\"_blank\" rel=\"noopener noreferrer\">src\/main.py<\/a>.<\/li>\n<\/ol>\n<p>Next, we set up the SageMaker PyTorch Estimator to handle the training job. To allow the managed Docker container initiated by SageMaker to connect to Amazon DocumentDB, we need to configure the subnet and security group for the Estimator.<\/p>\n<ol start=\"8\">\n<li>We retrieve the subnet ID where the <a href=\"https:\/\/docs.aws.amazon.com\/vpc\/latest\/userguide\/vpc-nat-gateway.html\" target=\"_blank\" rel=\"noopener noreferrer\">Network Address Translation (NAT) gateway<\/a> resides, as well as the security group ID of our Amazon DocumentDB cluster by name:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">ec2 = boto3.client(\"ec2\")\n# find the NAT gateway's subnet ID\nresp = ec2.describe_subnets(\n    Filters=[{\"Name\": \"tag:Name\", \"Values\": [\"{}-NATSubnet\".format(stack_name)]}]\n)\nnat_subnet_id = resp[\"Subnets\"][0][\"SubnetId\"]\n# find security group id of the DocumentDB\nresp = ec2.describe_security_groups(\n    Filters=[{\n        \"Name\": \"tag:Name\", \n        \"Values\": [\"{}-SG-DocumentDB\".format(stack_name)]\n    }])\nsg_id = resp[\"SecurityGroups\"][0][\"GroupId\"]<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"9\">\n<li>Finally, we can kick off the training of our GCN model using SageMaker:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">from sagemaker.pytorch import PyTorch\n\nCODE_PATH = \"main.py\"\n\nparams = {\n    \"patience\": 5, \n    \"n-epochs\": 200,\n    \"batch-size\": 64,\n    \"db-host\": secrets[\"host\"],\n    \"db-username\": secrets[\"username\"], \n    \"db-password\": secrets[\"password\"], \n    \"db-port\": secrets[\"port\"],\n    \"knn\": 4,\n}\n\nestimator = PyTorch(\n    entry_point=CODE_PATH,\n    source_dir=\"src\",\n    role=role,\n    instance_count=1,\n    instance_type=\"ml.p3.2xlarge\", # 'ml.c4.2xlarge' for CPU\n    
framework_version=\"1.7.1\",\n    py_version=\"py3\",\n    hyperparameters=params,\n    sagemaker_session=sess,\n    subnets=[nat_subnet_id], \n    security_group_ids=[sg_id],\n)\n# run the training job:\nestimator.fit()<\/code><\/pre>\n<\/p><\/div>\n<h2>Load and evaluate the trained GNN model<\/h2>\n<p>When the training job is complete, we can load the trained GCN model and perform some in-depth evaluation.<\/p>\n<p>The code for the following steps is also available in the notebook <a href=\"https:\/\/github.com\/aws-samples\/sagemaker-documentdb-train-gnn-for-millions-proteins\/blob\/main\/Train_and_eval.ipynb\" target=\"_blank\" rel=\"noopener noreferrer\">Train_and_eval.ipynb<\/a>.<\/p>\n<p>SageMaker training jobs save the model artifacts into the default S3 bucket, the URI of which can be accessed from the <code>estimator.model_data<\/code> attribute. We can also navigate to the <strong>Training jobs<\/strong> page on the SageMaker console to find the trained model to evaluate.<\/p>\n<ol>\n<li>For research purposes, we can load the model artifact (learned parameters) into a PyTorch <a href=\"https:\/\/pytorch.org\/tutorials\/recipes\/recipes\/what_is_state_dict.html\" target=\"_blank\" rel=\"noopener noreferrer\">state_dict<\/a> using the following function:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">def load_sagemaker_model_artifact(s3_bucket, key):\n    \"\"\"Load a PyTorch model artifact (model.tar.gz) produced by a SageMaker\n    Training job.\n    Args:\n        s3_bucket: str, s3 bucket name (s3:\/\/bucket_name)\n        key: object key: path to model.tar.gz from within the bucket\n    Returns:\n        state_dict: dict representing the PyTorch checkpoint\n    \"\"\"\n    # load the s3 object\n    s3 = boto3.client(\"s3\")\n    obj = s3.get_object(Bucket=s3_bucket, Key=key)\n    # read into memory\n    model_artifact = BytesIO(obj[\"Body\"].read())\n    # parse out the state dict from the tar.gz file\n    tar = 
tarfile.open(fileobj=model_artifact)\n    for member in tar.getmembers():\n        pth = tar.extractfile(member).read()\n\n    state_dict = torch.load(BytesIO(pth), map_location=torch.device(\"cpu\"))\n    return state_dict\n\nstate_dict = load_sagemaker_model_artifact(\n    bucket, \n    key=estimator.model_data.split(bucket)[1].lstrip(\"\/\")\n)\n\n# initialize a GCN model\nmodel = GCN(dim_nfeats, 16, n_classes)\n# load the learned parameters\nmodel.load_state_dict(state_dict[\"model_state_dict\"])<\/code><\/pre>\n<\/p><\/div>\n<ol start=\"2\">\n<li>Next, we perform quantitative model evaluation on the full test set by calculating accuracy:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\nnum_correct = 0\nnum_tests = 0\nmodel = model.to(device)\nmodel.eval()\nwith torch.no_grad():\n    for batched_graph, labels in test_loader:\n        batched_graph = batched_graph.to(device)\n        labels = labels.to(device)\n        logits = model(batched_graph)\n        preds = (logits.sigmoid() &gt; 0.5).to(labels.dtype)\n        num_correct += (preds == labels).sum().item()\n        num_tests += len(labels)\n\nprint('Test accuracy: {:.6f}'.format(num_correct \/ num_tests))<\/code><\/pre>\n<\/p><\/div>\n<p>We found our GCN model achieved an accuracy of 74.3%, whereas the dummy baseline model making predictions based on class priors only achieved 56.3%.<\/p>\n<p>We\u2019re also interested in the interpretability of our GCN model. 
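Global attention pooling computes a softmax over per-node gate scores and uses the normalized scores to weight each node's features in the graph readout. A minimal pure-Python sketch of that idea for a single graph (not the DGL implementation):

```python
import math

def global_attention_pool(node_feats, gate_scores):
    """Weight node features by softmax-normalized gate scores and sum them.

    node_feats: list of per-node feature vectors (lists of floats)
    gate_scores: one scalar gate score per node
    """
    # numerically stable softmax over the nodes of one graph
    m = max(gate_scores)
    exps = [math.exp(s - m) for s in gate_scores]
    total = sum(exps)
    attn = [e / total for e in exps]  # attention score per node
    # attention-weighted sum readout over nodes
    dim = len(node_feats[0])
    readout = [sum(a * h[d] for a, h in zip(attn, node_feats)) for d in range(dim)]
    return readout, attn

feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
readout, attn = global_attention_pool(feats, [0.0, 0.0, 0.0])
# equal gate scores -> uniform attention -> the readout is the mean of the features
```

With equal gate scores the readout degenerates to mean pooling; it is the learned gate layer that lets some residues dominate, and those per-node weights are the attention scores used to explain predictions.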
Because we implement a global attention pooling layer, we can compute the attention scores across nodes to explain specific predictions made by the model.<\/p>\n<ol start=\"3\">\n<li>Next, we compute the attention scores and overlay them on the protein graphs for a pair of structures (AlphaFold predicted and experimental) from the same peptide:<\/li>\n<\/ol>\n<div class=\"hide-language\">\n<pre><code class=\"lang-python\">pair = [\"AF-Q57887\", \"1JT8-A\"]\ncur = collection.find(\n    {\"id\": {\"$in\": pair}},\n)\n\nfor doc in cur:\n    # convert to dgl.graph object\n    graph = convert_to_graph(doc, k=4)\n    \n    with torch.no_grad():\n        # make prediction\n        pred = model(graph).sigmoid()\n        # calculate attention scores for a protein graph\n        attn = model.attention_scores(graph)\n    \n    pred = pred.item()\n    attn = attn.numpy()\n    \n    # convert to networkx graph for visualization\n    graph = graph.to_networkx().to_undirected()\n    # calculate graph layout\n    pos = nx.spring_layout(graph, iterations=500)\n    \n    fig, ax = plt.subplots(figsize=(8, 8))\n    nx.draw(\n        graph, \n        pos, \n        node_color=attn.flatten(),\n        cmap=\"Reds\",\n        with_labels=True, \n        font_size=8,\n        ax=ax\n    )\n    ax.set(title=\"{}, p(is_predicted)={:.6f}\".format(doc[\"id\"], pred))\nplt.show()<\/code><\/pre>\n<\/p><\/div>\n<p>The preceding code produces the following protein graphs overlaid with attention scores on the nodes. We find the model\u2019s global attention pooling layer can highlight certain residues in the protein graph as important for predicting whether the protein structure is an AlphaFold prediction. 
This indicates that these residues may have distinctive graph topologies in predicted and experimental protein structures.<\/p>\n<table width=\"725\">\n<tbody>\n<tr>\n<td><a href=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/ML-5542-image005.png\"><img decoding=\"async\" loading=\"lazy\" class=\"alignnone size-full wp-image-31648\" src=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/ML-5542-image005.png\" alt=\"\" width=\"431\" height=\"450\"><\/a><\/td>\n<td><a href=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/ML-5542-image003.png\"><img decoding=\"async\" loading=\"lazy\" class=\"alignnone size-full wp-image-31647\" src=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/ML-5542-image003.png\" alt=\"\" width=\"449\" height=\"447\"><\/a><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p>In summary, we showcase a scalable deep learning solution to train GNNs on protein structures stored in Amazon DocumentDB. Although the tutorial only uses thousands of proteins for training, this solution is scalable to millions of proteins. Unlike other approaches such as serializing the entire protein dataset, our approach transfers the memory-heavy workloads to the database, making the memory complexity for the training jobs <code>O<\/code>(<code>batch_size<\/code>), which is independent of the total number of proteins to train.<\/p>\n<h2>Clean up<\/h2>\n<p>To avoid incurring future charges, delete the CloudFormation stack you created. This removes all the resources you provisioned using the CloudFormation template, including the VPC, Amazon DocumentDB cluster, and SageMaker instance. 
For instructions, see <a href=\"https:\/\/docs.aws.amazon.com\/AWSCloudFormation\/latest\/UserGuide\/cfn-console-delete-stack.html\" target=\"_blank\" rel=\"noopener noreferrer\">Deleting a stack on the AWS CloudFormation console<\/a>.<\/p>\n<h2>Conclusion<\/h2>\n<p>We described a cloud-based deep learning architecture scalable to millions of protein structures by storing them in Amazon DocumentDB and efficiently retrieving mini-batches of data from SageMaker.<\/p>\n<p>To learn more about the use of GNN in protein property predictions, check out our recent publication <a href=\"https:\/\/www.biorxiv.org\/content\/10.1101\/2021.09.21.460852v1\" target=\"_blank\" rel=\"noopener noreferrer\">LM-GVP, A Generalizable Deep Learning Framework for Protein Property Prediction from Sequence and Structure<\/a>.<\/p>\n<hr>\n<h3>About the Authors<\/h3>\n<p><a href=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/Zichen-Wang.jpg\"><img decoding=\"async\" loading=\"lazy\" class=\"size-full wp-image-31651 alignleft\" src=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/12\/09\/Zichen-Wang.jpg\" alt=\"\" width=\"100\" height=\"133\"><\/a><strong>Zichen Wang<\/strong>, PhD, is an Applied Scientist in the Amazon Machine Learning Solutions Lab. 
With several years of research experience in developing ML and statistical methods using biological and medical data, he works with customers across various verticals to solve their ML problems.<\/p>\n<p><a href=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/11\/03\/Selvan-Senthivel.png\"><img decoding=\"async\" loading=\"lazy\" class=\"size-full wp-image-30401 alignleft\" src=\"https:\/\/d2908q01vomqb2.cloudfront.net\/f1f836cb4ea6efb2a0b1b99f41ad8b103eff4b59\/2021\/11\/03\/Selvan-Senthivel.png\" alt=\"\" width=\"100\" height=\"134\"><\/a><strong>Selvan Senthivel<\/strong> is a Senior ML Engineer with the Amazon ML Solutions Lab at AWS, focusing on helping customers on machine learning, deep learning problems, and end-to-end ML solutions. He was a founding engineering lead of Amazon Comprehend Medical and contributed to the design and architecture of multiple AWS AI services.<\/p>\n<p>       <!-- '\"` -->\n      <\/div>\n","protected":false},"excerpt":{"rendered":"<p>https:\/\/aws.amazon.com\/blogs\/machine-learning\/train-graph-neural-nets-for-millions-of-proteins-on-amazon-sagemaker-and-amazon-documentdb-with-mongodb-compatibility\/<\/p>\n","protected":false},"author":0,"featured_media":1438,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":[],"categories":[3],"tags":[],"_links":{"self":[{"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/posts\/1437"}],"collection":[{"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/types\/post"}],"replies":[{"embeddable":true,"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/comments?post=1437"}],"version-history":[{"count":0,"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/posts\/1437\/revisions"}],"wp:featuredmedia":[{"embeddable":tr
ue,"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/media\/1438"}],"wp:attachment":[{"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/media?parent=1437"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/categories?post=1437"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/salarydistribution.com\/machine-learning\/wp-json\/wp\/v2\/tags?post=1437"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}