Quellcode durchsuchen

Avoid overflow when shifting

The urts library and the signing tool often shift page counts as
32-bit integers, then passes the result as a 64-bit value. This patch
casts page counts into 64-bit integers first, so that large page
counts don't overflow.

Signed-off-by: Warren He <-w@berkeley.edu>
Warren He vor 7 Jahren
Ursprung
Commit
8bde48e653
2 geänderte Dateien mit 12 neuen und 12 gelöschten Zeilen
  1. 7 7
      psw/urts/loader.cpp
  2. 5 5
      sdk/sign_tool/SignTool/manage_metadata.cpp

+ 7 - 7
psw/urts/loader.cpp

@@ -265,14 +265,14 @@ int CLoader::build_context(const uint64_t start_rva, layout_entry_t *layout)
             ptcs->ogs_base += rva;
             m_tcs_list.push_back(GET_PTR(tcs_t, m_start_addr, rva));
             sinfo.flags = layout->si_flags;
-            if(SGX_SUCCESS != (ret = build_pages(rva, layout->page_count << SE_PAGE_SHIFT, added_page, sinfo, layout->attributes)))
+            if(SGX_SUCCESS != (ret = build_pages(rva, (uint64_t)layout->page_count << SE_PAGE_SHIFT, added_page, sinfo, layout->attributes)))
             {
                 return ret;
             }
         }
         else // guard page should not have content_offset != 0 
         {
-            section_info_t sec_info = {GET_PTR(uint8_t, m_metadata, layout->content_offset), layout->content_size, rva, layout->page_count << SE_PAGE_SHIFT, layout->si_flags, NULL};
+            section_info_t sec_info = {GET_PTR(uint8_t, m_metadata, layout->content_offset), layout->content_size, rva, (uint64_t)layout->page_count << SE_PAGE_SHIFT, layout->si_flags, NULL};
             if(SGX_SUCCESS != (ret = build_mem_region(&sec_info)))
             {
                 return ret;
@@ -292,7 +292,7 @@ int CLoader::build_context(const uint64_t start_rva, layout_entry_t *layout)
             }
             source = added_page;
         }
-        if(SGX_SUCCESS != (ret = build_pages(rva, layout->page_count << SE_PAGE_SHIFT, source, sinfo, layout->attributes)))
+        if(SGX_SUCCESS != (ret = build_pages(rva, (uint64_t)layout->page_count << SE_PAGE_SHIFT, source, sinfo, layout->attributes)))
         {
             return ret;
         }
@@ -438,7 +438,7 @@ int CLoader::validate_layout_table()
     {
         if(!IS_GROUP_ID(layout->entry.id))  // layout entry
         {
-            rva_vector.push_back(make_pair(layout->entry.rva, layout->entry.page_count << SE_PAGE_SHIFT));
+            rva_vector.push_back(make_pair(layout->entry.rva, (uint64_t)layout->entry.page_count << SE_PAGE_SHIFT));
             if(layout->entry.content_offset)
             {
                 if(false == is_metadata_buffer(layout->entry.content_offset, layout->entry.content_size))
@@ -467,7 +467,7 @@ int CLoader::validate_layout_table()
                     {
                         return SGX_ERROR_INVALID_METADATA;
                     }
-                    rva_vector.push_back(make_pair(entry->rva + load_step, entry->page_count << SE_PAGE_SHIFT));
+                    rva_vector.push_back(make_pair(entry->rva + load_step, (uint64_t)entry->page_count << SE_PAGE_SHIFT));
                     // no need to check integer overflow for entry->rva + load_step, because
                     // entry->rva and load_step are less than enclave_size, whose size is no more than 37 bit
                 }
@@ -751,13 +751,13 @@ int CLoader::set_context_protection(layout_t *layout_start, layout_t *layout_end
                 prot = SI_FLAGS_RW & SI_MASK_MEM_ATTRIBUTE;
             }
             ret = mprotect(GET_PTR(void, m_start_addr, layout->entry.rva + delta), 
-                               (size_t)(layout->entry.page_count << SE_PAGE_SHIFT),
+                               (size_t)layout->entry.page_count << SE_PAGE_SHIFT,
                                prot); 
             if(ret != 0)
             {
                 SE_TRACE(SE_TRACE_WARNING, "mprotect(rva=%" PRIu64 ", len=%" PRIu64 ", flags=%d) failed\n",
                          (uint64_t)m_start_addr + layout->entry.rva + delta, 
-                         (uint64_t)(layout->entry.page_count << SE_PAGE_SHIFT), 
+                         (uint64_t)layout->entry.page_count << SE_PAGE_SHIFT,
                           prot);
                 return SGX_ERROR_UNEXPECTED;
             }

+ 5 - 5
sdk/sign_tool/SignTool/manage_metadata.cpp

@@ -279,13 +279,13 @@ bool CMetadata::build_layout_entries(vector<layout_t> &layouts)
         if(!IS_GROUP_ID(layouts[i].entry.id))
         {
             layout_table->entry.rva = rva;
-            rva += (uint64_t)(layouts[i].entry.page_count << SE_PAGE_SHIFT);
+            rva += (uint64_t)layouts[i].entry.page_count << SE_PAGE_SHIFT;
         }
         else
         {
             for (uint32_t j = 0; j < layouts[i].group.entry_count; j++)
             {
-                layout_table->group.load_step += layouts[i-j-1].entry.page_count << SE_PAGE_SHIFT;
+                layout_table->group.load_step += (uint64_t)layouts[i-j-1].entry.page_count << SE_PAGE_SHIFT;
             }
             rva += layouts[i].group.load_times * layout_table->group.load_step;
         }
@@ -546,14 +546,14 @@ layout_entry_t *CMetadata::get_entry_by_id(uint16_t id)
 bool CMetadata::build_gd_template(uint8_t *data, uint32_t *data_size)
 {
     m_create_param.stack_limit_addr = get_entry_by_id(LAYOUT_ID_STACK)->rva - get_entry_by_id(LAYOUT_ID_TCS)->rva;
-    m_create_param.stack_base_addr = (get_entry_by_id(LAYOUT_ID_STACK)->page_count << SE_PAGE_SHIFT) + m_create_param.stack_limit_addr;
+    m_create_param.stack_base_addr = ((uint64_t)get_entry_by_id(LAYOUT_ID_STACK)->page_count << SE_PAGE_SHIFT) + m_create_param.stack_limit_addr;
     m_create_param.first_ssa_gpr = get_entry_by_id(LAYOUT_ID_SSA)->rva - get_entry_by_id(LAYOUT_ID_TCS)->rva
                                     + SSA_FRAME_SIZE * SE_PAGE_SIZE - (uint64_t)sizeof(ssa_gpr_t);
     m_create_param.enclave_size = m_metadata->enclave_size;
     m_create_param.heap_offset = get_entry_by_id(LAYOUT_ID_HEAP)->rva;
 
     uint64_t tmp_tls_addr = get_entry_by_id(LAYOUT_ID_TD)->rva - get_entry_by_id(LAYOUT_ID_TCS)->rva;
-    m_create_param.td_addr = tmp_tls_addr + ((get_entry_by_id(LAYOUT_ID_TD)->page_count - 1) << SE_PAGE_SHIFT);
+    m_create_param.td_addr = tmp_tls_addr + (((uint64_t)get_entry_by_id(LAYOUT_ID_TD)->page_count - 1) << SE_PAGE_SHIFT);
 
     const Section *section = m_parser->get_tls_section();
     if(section)
@@ -584,7 +584,7 @@ bool CMetadata::build_tcs_template(tcs_t *tcs)
     tcs->cssa = 0;
     tcs->ossa = get_entry_by_id(LAYOUT_ID_SSA)->rva - get_entry_by_id(LAYOUT_ID_TCS)->rva;
     //fs/gs pointer at TLS/TD
-    tcs->ofs_base = tcs->ogs_base = get_entry_by_id(LAYOUT_ID_TD)->rva - get_entry_by_id(LAYOUT_ID_TCS)->rva + (uint64_t)((get_entry_by_id(LAYOUT_ID_TD)->page_count - 1) << SE_PAGE_SHIFT);
+    tcs->ofs_base = tcs->ogs_base = get_entry_by_id(LAYOUT_ID_TD)->rva - get_entry_by_id(LAYOUT_ID_TCS)->rva + (((uint64_t)get_entry_by_id(LAYOUT_ID_TD)->page_count - 1) << SE_PAGE_SHIFT);
     tcs->ofs_limit = tcs->ogs_limit = (uint32_t)-1;
     return true;
 }