CodeGen.ml 63 KB


  1. (*
  2. * Copyright (C) 2011-2017 Intel Corporation. All rights reserved.
  3. *
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions
  6. * are met:
  7. *
  8. * * Redistributions of source code must retain the above copyright
  9. * notice, this list of conditions and the following disclaimer.
  10. * * Redistributions in binary form must reproduce the above copyright
  11. * notice, this list of conditions and the following disclaimer in
  12. * the documentation and/or other materials provided with the
  13. * distribution.
  14. * * Neither the name of Intel Corporation nor the names of its
  15. * contributors may be used to endorse or promote products derived
  16. * from this software without specific prior written permission.
  17. *
  18. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  19. * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  20. * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  21. * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  22. * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  23. * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  24. * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  28. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *
  30. *)
  31. open Printf
  32. open Util (* for failwithf *)
  33. (* --------------------------------------------------------------------
  34. * We first introduce a `parse_enclave_ast' function (see below) to
  35. * parse a value of type `Ast.enclave' into a `enclave_content' record.
  36. * --------------------------------------------------------------------
  37. *)
  (* Top-level items of one enclave, bucketed by kind (see
   * [parse_enclave_ast]) so that each generator can pick the list it
   * needs without walking the AST again. *)
  type enclave_content = {
    file_shortnm : string;                   (* short name of the EDL file *)
    enclave_name : string;                   (* normalized C identifier *)
    include_list : string list;              (* Ast.Include items *)
    import_exprs : Ast.import_decl list;     (* Ast.Importing items *)
    comp_defs    : Ast.composite_type list;  (* struct/union/enum definitions *)
    tfunc_decls  : Ast.trusted_func list;    (* trusted functions (ECALLs) *)
    ufunc_decls  : Ast.untrusted_func list;  (* untrusted functions (OCALLs) *)
  }
  (* When set, untrusted proxies are named `<enclave_name>_<func>'
   * instead of plain `<func>' (see [gen_uproxy_com_proto]). *)
  let g_use_prefix = ref false

  (* Directories that receive the generated untrusted and trusted files;
   * both default to the current directory. *)
  let g_untrusted_dir = ref "."
  let g_trusted_dir = ref "."

  (* An [enclave_content] whose fields are all empty. *)
  let empty_ec = {
    enclave_name = "";
    file_shortnm = "";
    include_list = [];
    import_exprs = [];
    comp_defs = [];
    tfunc_decls = [];
    ufunc_decls = [];
  }
  (* Name of a trusted function (ECALL). *)
  let get_tf_fname (tf: Ast.trusted_func) =
    let fd = tf.Ast.tf_fdecl in
    fd.Ast.fname

  (* Whether the ECALL is marked private. *)
  let is_priv_ecall (tf: Ast.trusted_func) =
    let { Ast.tf_is_priv = priv; _ } = tf in
    priv

  (* Name of an untrusted function (OCALL). *)
  let get_uf_fname (uf: Ast.untrusted_func) =
    let fd = uf.Ast.uf_fdecl in
    fd.Ast.fname

  (* Names of all ECALLs of [ec], in list order. *)
  let get_trusted_func_names (ec: enclave_content) =
    ec.tfunc_decls |> List.map (fun tf -> get_tf_fname tf)

  (* Names of all OCALLs of [ec], in list order. *)
  let get_untrusted_func_names (ec: enclave_content) =
    ec.ufunc_decls |> List.map (fun uf -> get_uf_fname uf)
  (* Function declarations of a list of ECALLs. *)
  let tf_list_to_fd_list (tfs: Ast.trusted_func list) =
    tfs |> List.map (fun (tf: Ast.trusted_func) -> tf.Ast.tf_fdecl)

  (* Private flags of a list of ECALLs, element-wise parallel to [tfs]. *)
  let tf_list_to_priv_list (tfs: Ast.trusted_func list) =
    tfs |> List.map is_priv_ecall

  (* Names of all private ECALLs, in list order. *)
  let get_priv_ecall_names (tfs: Ast.trusted_func list) =
    List.map get_tf_fname (List.filter is_priv_ecall tfs)

  (* Function declarations of a list of OCALLs. *)
  let uf_list_to_fd_list (ufs: Ast.untrusted_func list) =
    ufs |> List.map (fun (uf: Ast.untrusted_func) -> uf.Ast.uf_fdecl)

  (* Every ECALL name that appears in some OCALL's `allow(...)' list,
   * with duplicates removed by [dedup_list]. *)
  let get_allowed_names (ufs: Ast.untrusted_func list) =
    let per_ocall =
      List.map (fun (uf: Ast.untrusted_func) -> uf.Ast.uf_allow_list) ufs
    in
    dedup_list (List.concat per_ocall)
  (* Walk the items of an enclave AST once and bucket them by kind into an
   * [enclave_content]. Each list keeps the order in which its items occur
   * in [e.Ast.eexpr]: accumulators are built in reverse and flipped once
   * at the end. *)
  let parse_enclave_ast (e: Ast.enclave) =
    let add_func (ts, us) ef =
      match ef with
      | Ast.Trusted f -> (f :: ts, us)
      | Ast.Untrusted f -> (ts, f :: us)
    in
    let add_item (incs, imps, comps, ts, us) ex =
      match ex with
      | Ast.Composite x -> (incs, imps, x :: comps, ts, us)
      | Ast.Include x -> (x :: incs, imps, comps, ts, us)
      | Ast.Importing x -> (incs, x :: imps, comps, ts, us)
      | Ast.Interface xs ->
        let (ts, us) = List.fold_left add_func (ts, us) xs in
        (incs, imps, comps, ts, us)
    in
    let (incs, imps, comps, ts, us) =
      List.fold_left add_item ([], [], [], [], []) e.Ast.eexpr
    in
    { file_shortnm = e.Ast.ename;
      enclave_name = Util.to_c_identifier e.Ast.ename;
      include_list = List.rev incs;
      import_exprs = List.rev imps;
      comp_defs = List.rev comps;
      tfunc_decls = List.rev ts;
      ufunc_decls = List.rev us; }
  (* True iff [pt] is a pointer to an [Ast.Foreign] type whose pointer
   * attributes flag it as an array ([pa_isary]). *)
  let is_foreign_array (pt: Ast.parameter_type) =
    match pt with
    | Ast.PTPtr (Ast.Foreign _, a) -> a.Ast.pa_isary
    | Ast.PTPtr _ | Ast.PTVal _ -> false
  (* A "naked" function returns void and takes no parameters. *)
  let is_naked_func (fd: Ast.func_decl) =
    match fd.Ast.rtype, fd.Ast.plist with
    | Ast.Void, [] -> true
    | _ -> false
  (* If an enclave has no untrusted functions and every trusted function
   * is naked (no parameters, void result), the generated trusted bridges
   * call no tRTS routine. If the user's trusted code does not call one
   * either, the MSVC linker may leave the tRTS out of the enclave.
   * In that case this returns a snippet, active only under _MSC_VER,
   * that references sgx_is_within_enclave to force the link; otherwise
   * it returns "". *)
  let tbridge_gen_dummy_variable (ec: enclave_content) =
    let all_naked =
      List.for_all
        (fun (tf: Ast.trusted_func) -> is_naked_func tf.Ast.tf_fdecl)
        ec.tfunc_decls
    in
    if ec.ufunc_decls = [] && all_naked then
      String.concat "" [
        "\n#ifdef _MSC_VER\n";
        sprintf "\t/* In case enclave `%s' doesn't call any tRTS function. */\n"
          ec.enclave_name;
        "\tvolatile int force_link_trts = sgx_is_within_enclave(NULL, 0);\n";
        "\t(void) force_link_trts; /* avoid compiler warning */\n";
        "#endif\n\n";
      ]
    else ""
  (* Rewrite an array parameter as the equivalent pointer parameter, e.g.
   *   int array[10][20]   becomes   [count = 200] int* array
   * The pointer's size attribute is replaced by a count equal to the
   * product of the dimensions; other pointer attributes are kept.
   * Non-array parameters are returned unchanged. Used when generating
   * proxy/bridge code and the marshaling structures.
   *
   * XXX: assumes every dimension is > 0. An unspecified dimension is
   * stored as -1 and would yield a negative count. *)
  let conv_array_to_ptr (pd: Ast.pdecl): Ast.pdecl =
    let (pt, declr) = pd in
    match pt with
    | Ast.PTVal _ -> pd
    | Ast.PTPtr (aty, pa) ->
      if not (Ast.is_array declr) then pd
      else begin
        let total = List.fold_left ( * ) 1 declr.Ast.array_dims in
        let count_size =
          { Ast.empty_ptr_size with Ast.ps_count = Some (Ast.ANumber total) }
        in
        (Ast.PTPtr (Ast.Ptr aty, { pa with Ast.pa_size = count_size }),
         { declr with Ast.array_dims = [] })
      end
  164. (* ------------------------------------------------------------------
  165. * Code generation for edge-routines.
  166. * ------------------------------------------------------------------
  167. *)
  (* Naming conventions shared by the generated code. *)

  (* Name of the return-value slot / out-parameter, and its declarator. *)
  let retval_name = "retval"
  let retval_declr = { Ast.array_dims = []; Ast.identifier = retval_name }

  (* Name of the enclave-id parameter of untrusted proxies. *)
  let eid_name = "eid"

  (* Name of the untyped pointer to the marshaling struct handed to
   * bridges (see [mk_ubridge_proto]). *)
  let ms_ptr_name = "pms"

  (* Name of the marshaling-struct variable: a pointer in untrusted
   * bridges, a local struct value in untrusted proxies. *)
  let ms_struct_val = "ms"

  (* Member of the marshaling struct that holds parameter [pname]. *)
  let mk_ms_member_name (pname: string) = sprintf "ms_%s" pname

  (* Type name of the marshaling struct of function [fname]. *)
  let mk_ms_struct_name (fname: string) = sprintf "ms_%s_t" fname

  let ms_retval_name = mk_ms_member_name retval_name

  (* Trusted bridge of ECALL [fname]. *)
  let mk_tbridge_name (fname: string) = sprintf "sgx_%s" fname

  (* C expression ms->ms_<name>. *)
  let mk_parm_accessor name = ms_struct_val ^ "->" ^ mk_ms_member_name name

  (* Local temporaries derived from a parameter name. *)
  let mk_tmp_var name = sprintf "_tmp_%s" name
  let mk_len_var name = sprintf "_len_%s" name
  let mk_in_var name = sprintf "_in_%s" name

  (* OCALL table of an enclave. *)
  let mk_ocall_table_name enclave_name = sprintf "ocall_table_%s" enclave_name

  (* Untrusted bridges are prefixed with the EDL file short name. *)
  let mk_ubridge_name (file_shortnm: string) (funcname: string) =
    file_shortnm ^ "_" ^ funcname

  (* Prototype of an untrusted bridge, which receives the marshaling
   * struct as an untyped pointer. *)
  let mk_ubridge_proto (file_shortnm: string) (funcname: string) =
    "static sgx_status_t SGX_CDECL " ^ mk_ubridge_name file_shortnm funcname
    ^ "(void* " ^ ms_ptr_name ^ ")"
  (* Prologue shared by the generated headers: stdlib.h for size_t, the
   * SGX_CAST helper macro, and the opening of an extern C block for C++
   * consumers (closed again by [header_footer]). *)
  let common_macros =
    String.concat "\n" [
      "#include <stdlib.h> /* for size_t */";
      "";
      "#define SGX_CAST(type, item) ((type)(item))";
      "";
      "#ifdef __cplusplus";
      "extern \"C\" {";
      "#endif";
      "";
    ]

  (* Epilogue of the generated headers: closes the extern C block and the
   * include guard opened by the header preamble. *)
  let header_footer =
    String.concat "\n" [
      "";
      "#ifdef __cplusplus";
      "}";
      "#endif /* __cplusplus */";
      "";
      "#endif";
      "";
    ]
  (* Names of the generated files for an EDL file with short name
   * [file_shortnm]. The short-name variants return the bare header file
   * name; the others return a path inside [!g_untrusted_dir] or
   * [!g_trusted_dir], joined with [separator_str]. *)
  let get_uheader_short_name (file_shortnm: string) = sprintf "%s_u.h" file_shortnm
  let get_uheader_name (file_shortnm: string) =
    String.concat separator_str [!g_untrusted_dir; get_uheader_short_name file_shortnm]
  let get_usource_name (file_shortnm: string) =
    String.concat separator_str [!g_untrusted_dir; file_shortnm ^ "_u.c"]
  let get_theader_short_name (file_shortnm: string) = sprintf "%s_t.h" file_shortnm
  let get_theader_name (file_shortnm: string) =
    String.concat separator_str [!g_trusted_dir; get_theader_short_name file_shortnm]
  let get_tsource_name (file_shortnm: string) =
    String.concat separator_str [!g_trusted_dir; file_shortnm ^ "_t.c"]
  (* typedef'd struct definition named [name] whose member declarations
   * are the pre-rendered text [fs]. *)
  let mk_struct_decl (fs: string) (name: string) =
    String.concat "" ["typedef struct "; name; " {\n"; fs; "} "; name; ";\n"]

  (* Same as [mk_struct_decl], but for a union. *)
  let mk_union_decl (fs: string) (name: string) =
    String.concat "" ["typedef union "; name; " {\n"; fs; "} "; name; ";\n"]
  (* C definition of an enum. An unnamed enum becomes a plain anonymous
   * enum; a named one is typedef'd to its own name. Every enumerator,
   * with its explicit value if it has one, is followed by a comma. *)
  let mk_enum_def (e: Ast.enum_def) =
    let render_item ((key, value): Ast.enum_ele) =
      let text =
        match value with
        | Ast.EnumValNone -> key
        | Ast.EnumVal v -> sprintf "%s = %s" key (Ast.attr_value_to_string v)
      in
      "\t" ^ text ^ ",\n"
    in
    let body = String.concat "" (List.map render_item e.Ast.enbody) in
    match e.Ast.enname with
    | "" -> sprintf "enum {\n%s};\n" body
    | name -> sprintf "typedef enum %s {\n%s} %s;\n" name body name
  (* Render array dimensions as C declarator suffixes: [10; 20] gives
   * "[10][20]". An empty list (a plain identifier) gives "", and a
   * dimension of -1 (size not specified by the user) gives "[]". *)
  let get_array_dims (ns: int list) =
    let render = function
      | -1 -> "[]"
      | n -> sprintf "[%d]" n
    in
    String.concat "" (List.map render ns)
  (* C declaration text for [declr] of type [ty]: the type, a space, the
   * identifier and any array suffixes. *)
  let get_typed_declr_str (ty: Ast.atype) (declr: Ast.declarator) =
    String.concat "" [
      Ast.get_tystr ty; " "; declr.Ast.identifier;
      get_array_dims declr.Ast.array_dims ]
  (* One tab-indented, semicolon-terminated struct member line. *)
  let mk_member_decl (ty: Ast.atype) (declr: Ast.declarator) =
    "\t" ^ get_typed_declr_str ty declr ^ ";\n"
  (* Marshaling-struct member for the parameter named by [declr].
   * A foreign array type `foo_array_t' is stored by address,
   *   foo_array_t* ms_field;
   * matching how C passes arrays. Other parameters keep their type and
   * array dimensions. *)
  let mk_ms_member_decl (pt: Ast.parameter_type) (declr: Ast.declarator) =
    let base_ty = Ast.get_tystr (Ast.get_param_atype pt) in
    let by_address = if is_foreign_array pt then "* " else "" in
    String.concat "" [
      "\t"; base_ty; by_address; " ";
      mk_ms_member_name declr.Ast.identifier;
      get_array_dims declr.Ast.array_dims;
      ";\n" ]
  (* C definition of a user-declared struct, union or enum. *)
  let gen_comp_def (st: Ast.composite_type) =
    let members mlist =
      String.concat "" (List.map (fun (ty, declr) -> mk_member_decl ty declr) mlist)
    in
    match st with
    | Ast.StructDef s -> mk_struct_decl (members s.Ast.mlist) s.Ast.sname
    | Ast.UnionDef u -> mk_union_decl (members u.Ast.mlist) u.Ast.sname
    | Ast.EnumDef e -> mk_enum_def e
  (* One #include line per header, with the name in double quotes;
   * for example a.h becomes #include "a.h". *)
  let gen_include_list (xs: string list) =
    String.concat "" (List.map (sprintf "#include \"%s\"\n") xs)
  (* C type string of a parameter's underlying [Ast.atype]. *)
  let get_param_tystr (pt: Ast.parameter_type) =
    pt |> Ast.get_param_atype |> Ast.get_tystr
  (* Marshaling struct ms_<fname>_t for [fd]. It contains, in order: a
   * retval member when the function returns a value, [errno] (a
   * pre-rendered member declaration, possibly empty), and one member per
   * parameter with arrays turned into pointers. A void function with no
   * parameters and no errno member needs no struct, so the result is
   * the empty string. *)
  let gen_marshal_struct (fd: Ast.func_decl) (errno: string) =
    let struct_name = mk_ms_struct_name fd.Ast.fname in
    let param_members =
      fd.Ast.plist
      |> List.map conv_array_to_ptr
      |> List.map (fun (pt, declr) -> mk_ms_member_decl pt declr)
      |> String.concat ""
    in
    let members = errno ^ param_members in
    match fd.Ast.rtype with
    | Ast.Void ->
      if fd.Ast.plist = [] && errno = "" then ""
      else mk_struct_decl members struct_name
    | rtype ->
      let retval_member = mk_ms_member_decl (Ast.PTVal rtype) retval_declr in
      mk_struct_decl (retval_member ^ members) struct_name
  (* ECALL marshaling structs never carry an errno member. *)
  let gen_ecall_marshal_struct (tf: Ast.trusted_func) =
    gen_marshal_struct tf.Ast.tf_fdecl ""

  (* OCALLs that propagate errno get an extra ocall_errno member, which
   * the untrusted bridge fills in after the call. *)
  let gen_ocall_marshal_struct (uf: Ast.untrusted_func) =
    let errno_member =
      match uf.Ast.uf_propagate_errno with
      | true -> "\tint ocall_errno;\n"
      | false -> ""
    in
    gen_marshal_struct uf.Ast.uf_fdecl errno_member
  (* True for a read-only ([pa_rdonly]) pointer parameter whose type is
   * not foreign; such parameters are rendered with a const qualifier. *)
  let is_const_ptr (pt: Ast.parameter_type) =
    match pt with
    | Ast.PTVal _ -> false
    | Ast.PTPtr (_, pa) ->
      pa.Ast.pa_rdonly &&
      (match Ast.get_param_atype pt with
       | Ast.Foreign _ -> false
       | _ -> true)
  (* C text of one parameter declaration, prefixed with const when
   * [is_const_ptr] holds. *)
  let gen_parm_str (p: Ast.pdecl) =
    let (pt, declr) = p in
    let decl = get_typed_declr_str (Ast.get_param_atype pt) declr in
    (if is_const_ptr pt then "const " else "") ^ decl
  (* Declaration of the retval out-parameter, a pointer to the return
   * type, or "" for a void function. *)
  let gen_parm_retval (rt: Ast.atype) =
    match rt with
    | Ast.Void -> ""
    | _ -> sprintf "%s* %s" (Ast.get_tystr rt) retval_name
  315. (* ---------------------------------------------------------------------- *)
  (* Generate the ECALL table g_ecall_table. For a private ECALL foo and a
   * public ECALL bar it produces (address casts elided):
   *
   *   SGX_EXTERNC const struct {
   *     size_t nr_ecall;
   *     struct {void* ecall_addr; uint8_t is_priv;} ecall_table[2];
   *   } g_ecall_table = {
   *     2,
   *     { {sgx_foo, 1}, {sgx_bar, 0}, }
   *   };
   *
   * With no ECALLs the initializer list is omitted. *)
  let gen_ecall_table (tfs: Ast.trusted_func list) =
    let nr_ecall = List.length tfs in
    let row (tf: Ast.trusted_func) =
      sprintf "\t\t{(void*)(uintptr_t)%s, %d},\n"
        (mk_tbridge_name (get_tf_fname tf))
        (if is_priv_ecall tf then 1 else 0)
    in
    let rows =
      match tfs with
      | [] -> ""
      | _ -> "\t{\n" ^ String.concat "" (List.map row tfs) ^ "\t}\n"
    in
    String.concat "" [
      "SGX_EXTERNC const struct {\n";
      "\tsize_t nr_ecall;\n";
      sprintf "\tstruct {void* ecall_addr; uint8_t is_priv;} ecall_table[%d];\n" nr_ecall;
      "} g_ecall_table = {\n";
      sprintf "\t%d,\n" nr_ecall;
      rows;
      "};\n";
    ]
  (* Generate the dynamic entry table g_dyn_entry_table, where
   * entry_table[m][n] is 1 iff ECALL n appears in the `allow(...)' list
   * of OCALL m. For 3 OCALLs and 2 ECALLs:
   *
   *   SGX_EXTERNC const struct {
   *     size_t nr_ocall;
   *     uint8_t entry_table[3][2];
   *   } g_dyn_entry_table = {
   *     3,
   *     { {0, 0, }, {0, 1, }, {1, 0, }, }
   *   };
   *
   * The entry_table field and its initializer are emitted only when there
   * is at least one ECALL and at least one OCALL. *)
  let gen_entry_table (ec: enclave_content) =
    let ecall_names = get_trusted_func_names ec in
    let nr_ecall = List.length ecall_names in
    let nr_ocall = List.length ec.ufunc_decls in
    let row (uf: Ast.untrusted_func) =
      let flag name =
        if List.mem name uf.Ast.uf_allow_list then "1, " else "0, "
      in
      "\t\t{" ^ String.concat "" (List.map flag ecall_names) ^ "},\n"
    in
    let (field, table) =
      if nr_ecall > 0 && nr_ocall > 0 then
        (sprintf "\tuint8_t entry_table[%d][%d];\n" nr_ocall nr_ecall,
         "\t{\n" ^ String.concat "" (List.map row ec.ufunc_decls) ^ "\t}\n")
      else ("", "")
    in
    sprintf "SGX_EXTERNC const struct {\n\tsize_t nr_ocall;\n%s} g_dyn_entry_table = {\n\t%d,\n%s};\n"
      field nr_ocall table
  395. (* ---------------------------------------------------------------------- *)
  (* Generate the prototype of the trusted proxy for an untrusted
   * function, in COM style: the result is an sgx_status_t, and a non-void
   * return value is passed back through a leading retval pointer.
   * For example, the untrusted functions
   *   int foo(double d);
   *   void bar(float f);
   * get the trusted proxies
   *   sgx_status_t SGX_CDECL foo(int* retval, double d);
   *   sgx_status_t SGX_CDECL bar(float f); *)
  let gen_tproxy_proto (fd: Ast.func_decl) =
    let args =
      let params = List.map gen_parm_str fd.Ast.plist in
      match fd.Ast.rtype with
      | Ast.Void -> params
      | rt -> gen_parm_retval rt :: params
    in
    sprintf "sgx_status_t SGX_CDECL %s(%s)" fd.Ast.fname (String.concat ", " args)
  (* Generate the prototype of the untrusted proxy for a trusted function,
   * in COM style. The enclave id comes first, then a retval pointer when
   * the function returns a value, then the original parameters.
   * For example, the trusted functions
   *   int foo(double d);
   *   void bar(float f);
   * get the untrusted proxies
   *   sgx_status_t foo(sgx_enclave_id_t eid, int* retval, double d);
   *   sgx_status_t bar(sgx_enclave_id_t eid, float f);
   * When [!g_use_prefix] is true, the proxy is named `<prefix>_<fname>'. *)
  let gen_uproxy_com_proto (fd: Ast.func_decl) (prefix: string) =
    let name =
      if !g_use_prefix then prefix ^ "_" ^ fd.Ast.fname else fd.Ast.fname
    in
    let eid_arg = "sgx_enclave_id_t " ^ eid_name in
    let retval_args =
      match fd.Ast.rtype with
      | Ast.Void -> []
      | rt -> [gen_parm_retval rt]
    in
    let args = eid_arg :: retval_args @ List.map gen_parm_str fd.Ast.plist in
    sprintf "sgx_status_t %s(%s)" name (String.concat ", " args)
  (* C return type string of [fd]. *)
  let get_ret_tystr (fd: Ast.func_decl) = Ast.get_tystr fd.Ast.rtype

  (* Comma-separated parameter declarations of [fd] ("" when none). *)
  let get_plist_str (fd: Ast.func_decl) =
    String.concat ", " (List.map gen_parm_str fd.Ast.plist)
  (* Prototype of [fd] as declared in the EDL: return type, name and
   * parenthesized parameter list. *)
  let gen_func_proto (fd: Ast.func_decl) =
    String.concat "" [get_ret_tystr fd; " "; fd.Ast.fname; "("; get_plist_str fd; ")"]
  (* Declaration of an untrusted function for the untrusted header,
   * wrapped in SGX_UBRIDGE with its calling convention and, when the
   * function has the dllimport attribute, prefixed with SGX_DLLIMPORT. *)
  let gen_ufunc_proto (uf: Ast.untrusted_func) =
    let fd = uf.Ast.uf_fdecl in
    let attrs = uf.Ast.uf_fattr in
    let import_kw = if attrs.Ast.fa_dllimport then "SGX_DLLIMPORT " else "" in
    let cconv = "SGX_" ^ Ast.get_call_conv_str attrs.Ast.fa_convention in
    sprintf "%s%s SGX_UBRIDGE(%s, %s, (%s))"
      import_kw (get_ret_tystr fd) cconv fd.Ast.fname (get_plist_str fd)
  (* Preamble of the untrusted header: include guard, standard headers,
   * sgx_edger8r.h, the pre-rendered user include list [inclist], and
   * [common_macros]. *)
  let gen_uheader_preemble (guard: string) (inclist: string) =
    let std_includes =
      "#include <stdint.h>\n\
       #include <wchar.h>\n\
       #include <stddef.h>\n\
       #include <string.h>\n\
       #include \"sgx_edger8r.h\" /* for sgx_satus_t etc. */\n"
    in
    String.concat "" [
      sprintf "#ifndef %s\n#define %s\n\n" guard guard;
      std_includes;
      "\n";
      inclist;
      "\n";
      common_macros;
    ]
  (* Write every ECALL marshaling struct, then every OCALL one, to
   * [out_chan], each followed by a newline. Functions that need no
   * struct still produce an empty line. *)
  let ms_writer out_chan ec =
    let structs =
      let ecall_structs = List.map gen_ecall_marshal_struct ec.tfunc_decls in
      let ocall_structs = List.map gen_ocall_marshal_struct ec.ufunc_decls in
      ecall_structs @ ocall_structs
    in
    List.iter
      (fun s -> output_string out_chan s; output_string out_chan "\n")
      structs
  (* Generate the untrusted header <file_shortnm>_u.h in [!g_untrusted_dir]:
   * include guard and preamble, composite type definitions, the OCALL
   * prototypes, and the untrusted proxies for the ECALLs.
   *
   * If writing fails part-way, the output channel is closed before the
   * exception is re-raised. *)
  let gen_untrusted_header (ec: enclave_content) =
    let header_fname = get_uheader_name ec.file_shortnm in
    (* String.uppercase is deprecated since OCaml 4.03 and removed in 5.0.
     * enclave_name comes from Util.to_c_identifier and is presumably
     * ASCII-only, in which case the guard name is unchanged. *)
    let guard_macro = sprintf "%s_U_H__" (String.uppercase_ascii ec.enclave_name) in
    let preemble_code =
      let include_list = gen_include_list (ec.include_list @ !untrusted_headers) in
      gen_uheader_preemble guard_macro include_list
    in
    let comp_def_list = List.map gen_comp_def ec.comp_defs in
    let func_proto_ufunc = List.map gen_ufunc_proto ec.ufunc_decls in
    let uproxy_com_proto =
      List.map (fun (tf: Ast.trusted_func) ->
          gen_uproxy_com_proto tf.Ast.tf_fdecl ec.enclave_name)
        ec.tfunc_decls
    in
    let out_chan = open_out header_fname in
    let write_all suffix items =
      List.iter (fun s -> output_string out_chan (s ^ suffix)) items
    in
    (try
       output_string out_chan (preemble_code ^ "\n");
       write_all "\n" comp_def_list;
       write_all ";\n" func_proto_ufunc;
       output_string out_chan "\n";
       write_all ";\n" uproxy_com_proto;
       output_string out_chan header_footer
     with e ->
       close_out_noerr out_chan;
       raise e);
    close_out out_chan
  (* Preamble of the trusted header: include guard, standard headers,
   * sgx_edger8r.h, the pre-rendered user include list [inclist], and
   * [common_macros]. *)
  let gen_theader_preemble (guard: string) (inclist: string) =
    let buf = Buffer.create 256 in
    Buffer.add_string buf (sprintf "#ifndef %s\n#define %s\n\n" guard guard);
    Buffer.add_string buf
      "#include <stdint.h>\n\
       #include <wchar.h>\n\
       #include <stddef.h>\n\
       #include \"sgx_edger8r.h\" /* for sgx_ocall etc. */\n\n";
    Buffer.add_string buf inclist;
    Buffer.add_string buf "\n";
    Buffer.add_string buf common_macros;
    Buffer.contents buf
  (* Emit one prototype for every function named by a `sizefunc' pointer
   * attribute on any ECALL or OCALL parameter,
   *   size_t <sizefunc>(const <parameter type> val<dims>);
   * followed by a blank line. Several parameters may share a sizefunc only
   * if they have the same type and array dimensions; otherwise generation
   * fails through [failwithf].
   *
   * Prototypes are written in order of first use (ECALL parameters first,
   * then OCALL parameters). The output therefore does not depend on the
   * Hashtbl's internal bucket order, which changes when hashing is
   * randomized. *)
  let gen_sizefunc_proto out_chan (ec: enclave_content) =
    let fds = tf_list_to_fd_list ec.tfunc_decls @ uf_list_to_fd_list ec.ufunc_decls in
    (* sizefunc name -> (parameter type, array dims) of its first use *)
    let dict = Hashtbl.create 4 in
    (* sizefunc names, most recently first-seen at the head *)
    let first_seen = ref [] in
    let add_item (fname: string) (ty: Ast.atype * int list) =
      try
        if Hashtbl.find dict fname <> ty then
          failwithf "`%s' requires different parameter types" fname
      with Not_found ->
        Hashtbl.add dict fname ty;
        first_seen := fname :: !first_seen
    in
    let fill_dict ((pt, declr): Ast.pdecl) =
      match pt with
      | Ast.PTVal _ -> ()
      | Ast.PTPtr (aty, pattr) ->
        (match pattr.Ast.pa_size.Ast.ps_sizefunc with
         | Some s -> add_item s (aty, declr.Ast.array_dims)
         | None -> ())
    in
    List.iter (fun (fd: Ast.func_decl) -> List.iter fill_dict fd.Ast.plist) fds;
    List.iter (fun s ->
        let (aty, dims) = Hashtbl.find dict s in
        let declr = { Ast.identifier = "val"; Ast.array_dims = dims } in
        output_string out_chan
          (sprintf "size_t %s(const %s);\n" s (get_typed_declr_str aty declr)))
      (List.rev !first_seen);
    output_string out_chan "\n"
  (* Generate the trusted header <file_shortnm>_t.h in [!g_trusted_dir]:
   * include guard and preamble, composite type definitions, sizefunc
   * prototypes, the ECALL prototypes, and the trusted proxies for the
   * OCALLs.
   *
   * If generation fails part-way (for example, [gen_sizefunc_proto]
   * rejects a sizefunc used with conflicting parameter types), the output
   * channel is closed before the exception is re-raised. *)
  let gen_trusted_header (ec: enclave_content) =
    let header_fname = get_theader_name ec.file_shortnm in
    (* String.uppercase is deprecated since OCaml 4.03 and removed in 5.0.
     * enclave_name comes from Util.to_c_identifier and is presumably
     * ASCII-only, in which case the guard name is unchanged. *)
    let guard_macro = sprintf "%s_T_H__" (String.uppercase_ascii ec.enclave_name) in
    let guard_code =
      let include_list = gen_include_list (ec.include_list @ !trusted_headers) in
      gen_theader_preemble guard_macro include_list
    in
    let comp_def_list = List.map gen_comp_def ec.comp_defs in
    let func_proto_list = List.map gen_func_proto (tf_list_to_fd_list ec.tfunc_decls) in
    let func_tproxy_list = List.map gen_tproxy_proto (uf_list_to_fd_list ec.ufunc_decls) in
    let out_chan = open_out header_fname in
    let write_all suffix items =
      List.iter (fun s -> output_string out_chan (s ^ suffix)) items
    in
    (try
       output_string out_chan (guard_code ^ "\n");
       write_all "\n" comp_def_list;
       gen_sizefunc_proto out_chan ec;
       write_all ";\n" func_proto_list;
       output_string out_chan "\n";
       write_all ";\n" func_tproxy_list;
       output_string out_chan header_footer
     with e ->
       close_out_noerr out_chan;
       raise e);
    close_out out_chan
  (* C expression that reads the parameter named by [declr] out of the
   * marshaling struct (ms->ms_<name>). A multi-dimensional array is held
   * there as a plain pointer, so it is first cast back to a pointer to
   * its rows, i.e. a pointer to an array of the trailing dimensions. *)
  let mk_parm_name_raw (pt: Ast.parameter_type) (declr: Ast.declarator) =
    let accessor = mk_parm_accessor declr.Ast.identifier in
    match declr.Ast.array_dims with
    | _ :: (_ :: _ as inner_dims) when Ast.is_array declr ->
      sprintf "(%s (*)%s)%s" (get_param_tystr pt) (get_array_dims inner_dims) accessor
    | _ -> accessor
  (* A foreign array argument `foo_array_t foo' travels in the marshaling
   * struct as the address of its first element (see [fill_ms_field]).
   * For such a parameter, this wraps the expression rendered by [f] so
   * the callee receives the array back by dereferencing that address,
   * with a NULL guard. Other parameters are returned as [f] renders them. *)
  let add_foreign_array_ptrref
      (f: Ast.parameter_type -> Ast.declarator -> string)
      (pt: Ast.parameter_type)
      (declr: Ast.declarator) =
    let arg = f pt declr in
    match is_foreign_array pt with
    | true -> sprintf "(%s != NULL) ? (*%s) : NULL" arg arg
    | false -> arg
  (* Argument expression used by untrusted bridges: the raw marshaling
   * struct member, with foreign arrays dereferenced. *)
  let mk_parm_name_ubridge (pt: Ast.parameter_type) (declr: Ast.declarator) =
    add_foreign_array_ptrref (fun p d -> mk_parm_name_raw p d) pt declr
  (* Value parameters and pointers without a direction attribute are read
   * directly from the marshaling struct. Pointers with a direction are
   * referred to through the local _in_<name> variable. *)
  let mk_parm_name_ext (pt: Ast.parameter_type) (declr: Ast.declarator) =
    let read_raw =
      match pt with
      | Ast.PTVal _ -> true
      | Ast.PTPtr (_, { Ast.pa_direction = Ast.PtrNoDirection; _ }) -> true
      | Ast.PTPtr _ -> false
    in
    if read_raw then mk_parm_name_raw pt declr
    else mk_in_var declr.Ast.identifier
  (* C call statement for [fd]. Each argument is rendered by
   * [mk_parm_name], and arguments satisfying [is_const_ptr] get an
   * explicit const cast of their parameter type. *)
  let gen_func_invoking (fd: Ast.func_decl)
      (mk_parm_name: Ast.parameter_type -> Ast.declarator -> string) =
    let render_arg ((pt, declr): Ast.pdecl) =
      let name = mk_parm_name pt declr in
      if is_const_ptr pt then sprintf "(const %s)%s" (get_param_tystr pt) name
      else name
    in
    let args = List.map render_arg fd.Ast.plist in
    sprintf "%s(%s);" fd.Ast.fname (String.concat ", " args)
  (* Generate the untrusted bridge for the OCALL [ufunc]. The bridge
   * receives the marshaling struct through [ms_ptr_name] and calls the
   * real untrusted function with arguments read from it. Any return
   * value is stored into the retval member, and when errno propagation is
   * requested errno is copied into ocall_errno. A naked OCALL without
   * errno propagation has no struct, so its bridge rejects a non-NULL
   * pointer instead. *)
  let gen_func_ubridge (file_shortnm: string) (ufunc: Ast.untrusted_func) =
    let fd = ufunc.Ast.uf_fdecl in
    let fname = fd.Ast.fname in
    let call =
      let invocation = gen_func_invoking fd mk_parm_name_ubridge in
      match fd.Ast.rtype with
      | Ast.Void -> invocation
      | _ -> mk_parm_accessor retval_name ^ " = " ^ invocation
    in
    let header = mk_ubridge_proto file_shortnm fname ^ "\n{\n" in
    let footer = "\treturn SGX_SUCCESS;\n}\n" in
    if is_naked_func fd && not ufunc.Ast.uf_propagate_errno then
      String.concat "" [
        header;
        "\t"; sprintf "if (%s != NULL) return SGX_ERROR_INVALID_PARAMETER;" ms_ptr_name; "\n";
        "\t"; call; "\n";
        footer ]
    else begin
      let ms_type = mk_ms_struct_name fname in
      let errno_line =
        if ufunc.Ast.uf_propagate_errno then "\tms->ocall_errno = errno;" else ""
      in
      String.concat "" [
        header;
        "\t"; sprintf "%s* %s = SGX_CAST(%s*, %s);" ms_type ms_struct_val ms_type ms_ptr_name; "\n";
        "\t"; call; "\n";
        errno_line; "\n";
        footer ]
    end
  (* C statement storing parameter [pd] into the marshaling struct named by
   * [ms_struct_val]. [isptr] selects the pointer member accessor (->)
   * or the value member accessor (.).
   *   - arrays are stored by address, cast to a pointer to their type;
   *   - foreign arrays store the address of their first element;
   *   - read-only pointers get an explicit cast to their type string;
   *   - anything else is assigned as is. *)
  let fill_ms_field (isptr: bool) (pd: Ast.pdecl) =
    let (pt, declr) = pd in
    let pname = declr.Ast.identifier in
    let lhs =
      ms_struct_val ^ (if isptr then "->" else ".") ^ mk_ms_member_name pname
    in
    let assign rhs = sprintf "%s = %s;" lhs rhs in
    match declr.Ast.array_dims, pt with
    | _ :: _, _ ->
      (* Arrays are passed by address. *)
      let ptr_ty = Ast.get_tystr (Ast.Ptr (Ast.get_param_atype pt)) in
      assign (sprintf "(%s)%s" ptr_ty pname)
    | [], Ast.PTVal _ -> assign pname
    | [], Ast.PTPtr (aty, pattr) ->
      if pattr.Ast.pa_isary then
        assign (sprintf "(%s *)&%s[0]" (Ast.get_tystr aty) pname)
      else if pattr.Ast.pa_rdonly then
        assign (sprintf "(%s)%s" (Ast.get_tystr aty) pname)
      else assign pname
  (* Generate the untrusted proxy for the ECALL [fd] with table index
   * [idx]. It fills a local marshaling struct from the arguments and
   * calls sgx_ecall with the enclave's OCALL table. The return value is
   * copied out through the retval pointer when the call succeeds and
   * that pointer is non-NULL. A naked ECALL passes NULL instead of a
   * struct. *)
  let gen_func_uproxy (fd: Ast.func_decl) (idx: int) (ec: enclave_content) =
    let prologue =
      gen_uproxy_com_proto fd ec.enclave_name ^ "\n{\n\tsgx_status_t status;\n"
    in
    let epilogue = "\treturn status;\n}\n" in
    let table_ref = "&" ^ mk_ocall_table_name ec.enclave_name in
    let ecall ms_arg =
      sprintf "status = sgx_ecall(%s, %d, %s, %s);" eid_name idx table_ref ms_arg
    in
    let statements =
      if is_naked_func fd then [ecall "NULL"]
      else begin
        let declare = sprintf "%s %s;" (mk_ms_struct_name fd.Ast.fname) ms_struct_val in
        let fills = List.map (fill_ms_field false) fd.Ast.plist in
        let copy_back =
          match fd.Ast.rtype with
          | Ast.Void -> []
          | _ ->
            [sprintf "if (status == SGX_SUCCESS && %s) *%s = %s.%s;"
               retval_name retval_name ms_struct_val ms_retval_name]
        in
        declare :: fills @ (ecall ("&" ^ ms_struct_val) :: copy_back)
      end
    in
    String.concat "" (prologue :: List.map (fun s -> "\t" ^ s ^ "\n") statements)
    ^ epilogue
  695. (* Generate an expression to check the pointers. *)
  696. let mk_check_ptr (name: string) (lenvar: string) =
  697. let checker = "CHECK_UNIQUE_POINTER"
  698. in sprintf "\t%s(%s, %s);\n" checker name lenvar
  699. (* Pointer to marshaling structure should never be NULL. *)
  700. let mk_check_pms (fname: string) =
  701. let lenvar = sprintf "sizeof(%s)" (mk_ms_struct_name fname)
  702. in sprintf "\t%s(%s, %s);\n" "CHECK_REF_POINTER" ms_ptr_name lenvar
  703. (* Generate code to get the size of the pointer. *)
  704. let gen_ptr_size (ty: Ast.atype) (pattr: Ast.ptr_attr) (name: string) (get_parm: string -> string) =
  705. let len_var = mk_len_var name in
  706. let parm_name = get_parm name in
  707. let mk_len_size v =
  708. match v with
  709. Ast.AString s -> get_parm s
  710. | Ast.ANumber n -> sprintf "%d" n in
  711. let mk_len_count v size_str =
  712. match v with
  713. Ast.AString s -> sprintf "%s * %s" (get_parm s) size_str
  714. | Ast.ANumber n -> sprintf "%d * %s" n size_str in
  715. let mk_len_sizefunc s = sprintf "((%s) ? %s(%s) : 0)" parm_name s parm_name in
  716. (* Note, during the parsing stage, we already eliminated the case that
  717. * user specified both 'size' and 'sizefunc' attribute.
  718. *)
  719. let do_attribute (pattr: Ast.ptr_attr) =
  720. let do_ps_attribute (sattr: Ast.ptr_size) =
  721. let size_str =
  722. match sattr.Ast.ps_size with
  723. Some a -> mk_len_size a
  724. | None ->
  725. match sattr.Ast.ps_sizefunc with
  726. None -> sprintf "sizeof(*%s)" parm_name
  727. | Some a -> mk_len_sizefunc a
  728. in
  729. match sattr.Ast.ps_count with
  730. None -> size_str
  731. | Some a -> mk_len_count a size_str
  732. in
  733. if pattr.Ast.pa_isstr then
  734. sprintf "%s ? strlen(%s) + 1 : 0" parm_name parm_name
  735. else if pattr.Ast.pa_iswstr then
  736. sprintf "%s ? (wcslen(%s) + 1) * sizeof(wchar_t) : 0" parm_name parm_name
  737. else
  738. do_ps_attribute pattr.Ast.pa_size
  739. in
  740. sprintf "size_t %s = %s;\n"
  741. len_var
  742. (if pattr.Ast.pa_isary
  743. then sprintf "sizeof(%s)" (Ast.get_tystr ty)
  744. else do_attribute pattr)
  745. (* Find the data type of a parameter. *)
  746. let find_param_type (name: string) (plist: Ast.pdecl list) =
  747. try
  748. let (pt, _) = List.find (fun (pd: Ast.pdecl) ->
  749. let (pt, declr) = pd
  750. in declr.Ast.identifier = name) plist
  751. in get_param_tystr pt
  752. with
  753. Not_found -> failwithf "parameter `%s' not found." name
  754. (* Generate code to check the length of buffers. *)
  755. let gen_check_tbridge_length_overflow (plist: Ast.pdecl list) =
  756. let gen_check_length (ty: Ast.atype) (attr: Ast.ptr_attr) (declr: Ast.declarator) =
  757. let name = declr.Ast.identifier in
  758. let tmp_ptr_name= mk_tmp_var name in
  759. let mk_len_size v =
  760. match v with
  761. Ast.AString s -> mk_tmp_var s
  762. | Ast.ANumber n -> sprintf "%d" n in
  763. let mk_len_sizefunc s = sprintf "((%s) ? %s(%s) : 0)" tmp_ptr_name s tmp_ptr_name in
  764. let gen_check_overflow cnt size_str =
  765. let if_statement =
  766. match cnt with
  767. Ast.AString s -> sprintf "\tif ((size_t)%s > (SIZE_MAX / %s)) {\n" (mk_tmp_var s) size_str
  768. | Ast.ANumber n -> sprintf "\tif (%d > (SIZE_MAX / %s)) {\n" n size_str
  769. in
  770. sprintf "%s\t\tstatus = SGX_ERROR_INVALID_PARAMETER;\n\t\tgoto err;\n\t}" if_statement
  771. in
  772. let size_str =
  773. match attr.Ast.pa_size.Ast.ps_size with
  774. Some a -> mk_len_size a
  775. | None ->
  776. match attr.Ast.pa_size.Ast.ps_sizefunc with
  777. None -> sprintf "sizeof(*%s)" tmp_ptr_name
  778. | Some a -> mk_len_sizefunc a
  779. in
  780. match attr.Ast.pa_size.Ast.ps_count with
  781. None -> ""
  782. | Some a -> sprintf "%s\n\n" (gen_check_overflow a size_str)
  783. in
  784. List.fold_left
  785. (fun acc (pty, declr) ->
  786. match pty with
  787. Ast.PTVal _ -> acc
  788. | Ast.PTPtr(ty, attr) -> acc ^ gen_check_length ty attr declr) "" plist
  789. (* Generate code to check all function parameters which are pointers. *)
  790. let gen_check_tbridge_ptr_parms (plist: Ast.pdecl list) =
  791. let gen_check_ptr (ty: Ast.atype) (pattr: Ast.ptr_attr) (declr: Ast.declarator) =
  792. if not pattr.Ast.pa_chkptr then ""
  793. else
  794. let name = declr.Ast.identifier in
  795. let len_var = mk_len_var name in
  796. let parm_name = mk_tmp_var name in
  797. if pattr.Ast.pa_chkptr
  798. then mk_check_ptr parm_name len_var
  799. else ""
  800. in
  801. let new_param_list = List.map conv_array_to_ptr plist
  802. in
  803. List.fold_left
  804. (fun acc (pty, declr) ->
  805. match pty with
  806. Ast.PTVal _ -> acc
  807. | Ast.PTPtr(ty, attr) -> acc ^ gen_check_ptr ty attr declr) "" new_param_list
  808. (* If a foreign type is a readonly pointer, we cast it to 'void*' for memcpy() and free() *)
  809. let mk_in_ptr_dst_name (rdonly: bool) (ptr_name: string) =
  810. if rdonly then "(void*)" ^ ptr_name
  811. else ptr_name
  812. (* Generate the code to handle function pointer parameter direction,
  813. * which is to be inserted before actually calling the trusted function.
  814. *)
  815. let gen_parm_ptr_direction_pre (plist: Ast.pdecl list) =
  816. let clone_in_ptr (ty: Ast.atype) (attr: Ast.ptr_attr) (declr: Ast.declarator) =
  817. let name = declr.Ast.identifier in
  818. let is_ary = (Ast.is_array declr || attr.Ast.pa_isary) in
  819. let in_ptr_name = mk_in_var name in
  820. let in_ptr_type = sprintf "%s%s" (Ast.get_tystr ty) (if is_ary then "*" else "") in
  821. let len_var = mk_len_var name in
  822. let in_ptr_dst_name = mk_in_ptr_dst_name attr.Ast.pa_rdonly in_ptr_name in
  823. let tmp_ptr_name= mk_tmp_var name in
  824. let check_sizefunc_ptr (fn: string) =
  825. sprintf "\t\t/* check whether the pointer is modified. */\n\
  826. \t\tif (%s(%s) != %s) {\n\
  827. \t\t\tstatus = SGX_ERROR_INVALID_PARAMETER;\n\
  828. \t\t\tgoto err;\n\
  829. \t\t}" fn in_ptr_name len_var
  830. in
  831. let malloc_and_copy pre_indent =
  832. match attr.Ast.pa_direction with
  833. Ast.PtrIn | Ast.PtrInOut ->
  834. let code_template = [
  835. sprintf "if (%s != NULL) {" tmp_ptr_name;
  836. sprintf "\t%s = (%s)malloc(%s);" in_ptr_name in_ptr_type len_var;
  837. sprintf "\tif (%s == NULL) {" in_ptr_name;
  838. "\t\tstatus = SGX_ERROR_OUT_OF_MEMORY;";
  839. "\t\tgoto err;";
  840. "\t}\n";
  841. sprintf "\tmemcpy(%s, %s, %s);" in_ptr_dst_name tmp_ptr_name len_var;
  842. ]
  843. in
  844. let s1 = List.fold_left (fun acc s -> acc ^ pre_indent ^ s ^ "\n") "" code_template in
  845. let s2 =
  846. if attr.Ast.pa_isstr
  847. then sprintf "%s\t\t%s[%s - 1] = '\\0';\n" s1 in_ptr_name len_var
  848. else if attr.Ast.pa_iswstr
  849. then sprintf "%s\t\t%s[(%s - sizeof(wchar_t))/sizeof(wchar_t)] = (wchar_t)0;\n" s1 in_ptr_name len_var
  850. else s1 in
  851. let s3 =
  852. match attr.Ast.pa_size.Ast.ps_sizefunc with
  853. None -> s2
  854. | Some s -> sprintf "%s\n%s\n" s2 (check_sizefunc_ptr(s))
  855. in sprintf "%s\t}\n" s3
  856. | Ast.PtrOut ->
  857. let code_template = [
  858. sprintf "if (%s != NULL) {" tmp_ptr_name;
  859. sprintf "\tif ((%s = (%s)malloc(%s)) == NULL) {" in_ptr_name in_ptr_type len_var;
  860. "\t\tstatus = SGX_ERROR_OUT_OF_MEMORY;";
  861. "\t\tgoto err;";
  862. "\t}\n";
  863. sprintf "\tmemset((void*)%s, 0, %s);" in_ptr_name len_var;
  864. "}"]
  865. in
  866. List.fold_left (fun acc s -> acc ^ pre_indent ^ s ^ "\n") "" code_template
  867. | _ -> ""
  868. in
  869. malloc_and_copy "\t"
  870. in List.fold_left
  871. (fun acc (pty, declr) ->
  872. match pty with
  873. Ast.PTVal _ -> acc
  874. | Ast.PTPtr (ty, attr) -> acc ^ clone_in_ptr ty attr declr) "" plist
  875. (* Generate the code to handle function pointer parameter direction,
  876. * which is to be inserted after finishing calling the trusted function.
  877. *)
  878. let gen_parm_ptr_direction_post (plist: Ast.pdecl list) =
  879. let copy_and_free (attr: Ast.ptr_attr) (declr: Ast.declarator) =
  880. let name = declr.Ast.identifier in
  881. let in_ptr_name = mk_in_var name in
  882. let len_var = mk_len_var name in
  883. let in_ptr_dst_name = mk_in_ptr_dst_name attr.Ast.pa_rdonly in_ptr_name in
  884. match attr.Ast.pa_direction with
  885. Ast.PtrIn -> sprintf "\tif (%s) free(%s);\n" in_ptr_name in_ptr_dst_name
  886. | Ast.PtrInOut | Ast.PtrOut ->
  887. sprintf "\tif (%s) {\n\t\tmemcpy(%s, %s, %s);\n\t\tfree(%s);\n\t}\n"
  888. in_ptr_name
  889. (mk_tmp_var name)
  890. in_ptr_name
  891. len_var
  892. in_ptr_name
  893. | _ -> ""
  894. in List.fold_left
  895. (fun acc (pty, declr) ->
  896. match pty with
  897. Ast.PTVal _ -> acc
  898. | Ast.PTPtr (ty, attr) -> acc ^ copy_and_free attr declr) "" plist
  899. (* Generate an "err:" goto mark if necessary. *)
  900. let gen_err_mark (plist: Ast.pdecl list) =
  901. let has_inout_p (attr: Ast.ptr_attr): bool =
  902. attr.Ast.pa_direction <> Ast.PtrNoDirection
  903. in
  904. if List.exists (fun (pt, name) ->
  905. match pt with
  906. Ast.PTVal _ -> false
  907. | Ast.PTPtr(_, attr) -> has_inout_p attr) plist
  908. then "err:"
  909. else ""
  910. (* It is used to save the parameters used as the value of size/count attribute. *)
  911. let param_cache = Hashtbl.create 1
  912. let is_in_param_cache s = Hashtbl.mem param_cache s
  913. (* Try to generate a temporary value to save the size of the buffer. *)
  914. let gen_tmp_size (pattr: Ast.ptr_attr) (plist: Ast.pdecl list) =
  915. let do_gen_temp_var (s: string) =
  916. if is_in_param_cache s then ""
  917. else
  918. let param_tystr = find_param_type s plist in
  919. let tmp_var = mk_tmp_var s in
  920. let parm_str = mk_parm_accessor s in
  921. Hashtbl.add param_cache s true;
  922. sprintf "\t%s %s = %s;\n" param_tystr tmp_var parm_str
  923. in
  924. let gen_temp_var (v: Ast.attr_value) =
  925. match v with
  926. Ast.ANumber _ -> ""
  927. | Ast.AString s -> do_gen_temp_var s
  928. in
  929. let tmp_size_str =
  930. match pattr.Ast.pa_size.Ast.ps_size with
  931. Some v -> gen_temp_var v
  932. | None -> ""
  933. in
  934. let tmp_count_str =
  935. match pattr.Ast.pa_size.Ast.ps_count with
  936. Some v -> gen_temp_var v
  937. | None -> ""
  938. in
  939. sprintf "%s%s" tmp_size_str tmp_count_str
  940. let is_ptr (pt: Ast.parameter_type) =
  941. match pt with
  942. Ast.PTVal _ -> false
  943. | Ast.PTPtr _ -> true
  944. let is_ptr_type (aty: Ast.atype) =
  945. match aty with
  946. Ast.Ptr _ -> true
  947. | _ -> false
  948. let ptr_has_direction (pt: Ast.parameter_type) =
  949. match pt with
  950. Ast.PTVal _ -> false
  951. | Ast.PTPtr(_, a) -> a.Ast.pa_direction <> Ast.PtrNoDirection
  952. let tbridge_mk_parm_name_ext (pt: Ast.parameter_type) (declr: Ast.declarator) =
  953. if is_in_param_cache declr.Ast.identifier || (is_ptr pt && (not (is_foreign_array pt)))
  954. then
  955. if ptr_has_direction pt
  956. then mk_in_var declr.Ast.identifier
  957. else mk_tmp_var declr.Ast.identifier
  958. else mk_parm_name_ext pt declr
  959. let mk_parm_name_tbridge (pt: Ast.parameter_type) (declr: Ast.declarator) =
  960. add_foreign_array_ptrref tbridge_mk_parm_name_ext pt declr
  961. (* Generate local variables required for the trusted bridge. *)
  962. let gen_tbridge_local_vars (plist: Ast.pdecl list) =
  963. let status_var = "\tsgx_status_t status = SGX_SUCCESS;\n" in
  964. let do_gen_local_var (ty: Ast.atype) (attr: Ast.ptr_attr) (name: string) =
  965. let tmp_var =
  966. (* Save a copy of pointer in case it might be modified in the marshaling structure. *)
  967. sprintf "\t%s %s = %s;\n" (Ast.get_tystr ty) (mk_tmp_var name) (mk_parm_accessor name)
  968. in
  969. let len_var =
  970. if not attr.Ast.pa_chkptr then ""
  971. else gen_tmp_size attr plist ^ "\t" ^ gen_ptr_size ty attr name mk_tmp_var in
  972. let in_ptr =
  973. match attr.Ast.pa_direction with
  974. Ast.PtrNoDirection -> ""
  975. | _ -> sprintf "\t%s %s = NULL;\n" (Ast.get_tystr ty) (mk_in_var name)
  976. in
  977. tmp_var ^ len_var ^ in_ptr
  978. in
  979. let gen_local_var_for_foreign_array (ty: Ast.atype) (attr: Ast.ptr_attr) (name: string) =
  980. let tystr = Ast.get_tystr ty in
  981. let tmp_var =
  982. sprintf "\t%s* %s = %s;\n" tystr (mk_tmp_var name) (mk_parm_accessor name)
  983. in
  984. let len_var = sprintf "\tsize_t %s = sizeof(%s);\n" (mk_len_var name) tystr
  985. in
  986. let in_ptr = sprintf "\t%s* %s = NULL;\n" tystr (mk_in_var name)
  987. in
  988. match attr.Ast.pa_direction with
  989. Ast.PtrNoDirection -> ""
  990. | _ -> tmp_var ^ len_var ^ in_ptr
  991. in
  992. let gen_local_var (pd: Ast.pdecl) =
  993. let (pty, declr) = pd in
  994. match pty with
  995. Ast.PTVal _ -> ""
  996. | Ast.PTPtr (ty, attr) ->
  997. if is_foreign_array pty
  998. then gen_local_var_for_foreign_array ty attr declr.Ast.identifier
  999. else do_gen_local_var ty attr declr.Ast.identifier
  1000. in
  1001. let new_param_list = List.map conv_array_to_ptr plist
  1002. in
  1003. Hashtbl.clear param_cache;
  1004. List.fold_left (fun acc pd -> acc ^ gen_local_var pd) status_var new_param_list
  1005. (* It generates trusted bridge code for a trusted function. *)
  1006. let gen_func_tbridge (fd: Ast.func_decl) (dummy_var: string) =
  1007. let func_open = sprintf "static sgx_status_t SGX_CDECL %s(void* %s)\n{\n"
  1008. (mk_tbridge_name fd.Ast.fname)
  1009. ms_ptr_name in
  1010. let local_vars = gen_tbridge_local_vars fd.Ast.plist in
  1011. let func_close = "\treturn status;\n}\n" in
  1012. let ms_struct_name = mk_ms_struct_name fd.Ast.fname in
  1013. let declare_ms_ptr = sprintf "%s* %s = SGX_CAST(%s*, %s);"
  1014. ms_struct_name
  1015. ms_struct_val
  1016. ms_struct_name
  1017. ms_ptr_name in
  1018. let invoke_func = gen_func_invoking fd mk_parm_name_tbridge in
  1019. let update_retval = sprintf "%s = %s"
  1020. (mk_parm_accessor retval_name)
  1021. invoke_func in
  1022. if is_naked_func fd then
  1023. let check_pms =
  1024. sprintf "if (%s != NULL) return SGX_ERROR_INVALID_PARAMETER;" ms_ptr_name
  1025. in
  1026. sprintf "%s%s%s\t%s\n\t%s\n%s" func_open local_vars dummy_var check_pms invoke_func func_close
  1027. else
  1028. sprintf "%s%s\t%s\n%s\n%s%s\n%s\t%s\n%s\n%s\n%s"
  1029. func_open
  1030. (mk_check_pms fd.Ast.fname)
  1031. declare_ms_ptr
  1032. local_vars
  1033. (gen_check_tbridge_length_overflow fd.Ast.plist)
  1034. (gen_check_tbridge_ptr_parms fd.Ast.plist)
  1035. (gen_parm_ptr_direction_pre fd.Ast.plist)
  1036. (if fd.Ast.rtype <> Ast.Void then update_retval else invoke_func)
  1037. (gen_err_mark fd.Ast.plist)
  1038. (gen_parm_ptr_direction_post fd.Ast.plist)
  1039. func_close
  1040. let tproxy_fill_ms_field (pd: Ast.pdecl) =
  1041. let (pt, declr) = pd in
  1042. let name = declr.Ast.identifier in
  1043. let len_var = mk_len_var name in
  1044. let parm_accessor = mk_parm_accessor name in
  1045. match pt with
  1046. Ast.PTVal _ -> fill_ms_field true pd
  1047. | Ast.PTPtr(ty, attr) ->
  1048. let is_ary = (Ast.is_array declr || attr.Ast.pa_isary) in
  1049. let tystr = sprintf "%s%s" (get_param_tystr pt) (if is_ary then "*" else "") in
  1050. if is_ary && is_ptr_type ty then
  1051. sprintf "\n#pragma message(\"Pointer array `%s' in trusted proxy `\"\
  1052. __FUNCTION__ \"' is dangerous. No code generated.\")\n" name
  1053. else
  1054. let in_ptr_dst_name = mk_in_ptr_dst_name attr.Ast.pa_rdonly parm_accessor in
  1055. if not attr.Ast.pa_chkptr (* [user_check] specified *)
  1056. then sprintf "%s = SGX_CAST(%s, %s);" parm_accessor tystr name
  1057. else
  1058. match attr.Ast.pa_direction with
  1059. Ast.PtrOut ->
  1060. let code_template =
  1061. [sprintf "if (%s != NULL && sgx_is_within_enclave(%s, %s)) {" name name len_var;
  1062. sprintf "\t%s = (%s)__tmp;" parm_accessor tystr;
  1063. sprintf "\t__tmp = (void *)((size_t)__tmp + %s);" len_var;
  1064. sprintf "\tmemset(%s, 0, %s);" in_ptr_dst_name len_var;
  1065. sprintf "} else if (%s == NULL) {" name;
  1066. sprintf "\t%s = NULL;" parm_accessor;
  1067. "} else {";
  1068. "\tsgx_ocfree();";
  1069. "\treturn SGX_ERROR_INVALID_PARAMETER;";
  1070. "}"
  1071. ]
  1072. in List.fold_left (fun acc s -> acc ^ s ^ "\n\t") "" code_template
  1073. | _ ->
  1074. let code_template =
  1075. [sprintf "if (%s != NULL && sgx_is_within_enclave(%s, %s)) {" name name len_var;
  1076. sprintf "\t%s = (%s)__tmp;" parm_accessor tystr;
  1077. sprintf "\t__tmp = (void *)((size_t)__tmp + %s);" len_var;
  1078. sprintf "\tmemcpy(%s, %s, %s);" in_ptr_dst_name name len_var;
  1079. sprintf "} else if (%s == NULL) {" name;
  1080. sprintf "\t%s = NULL;" parm_accessor;
  1081. "} else {";
  1082. "\tsgx_ocfree();";
  1083. "\treturn SGX_ERROR_INVALID_PARAMETER;";
  1084. "}"
  1085. ]
  1086. in List.fold_left (fun acc s -> acc ^ s ^ "\n\t") "" code_template
  1087. (* Generate local variables required for the trusted proxy. *)
  1088. let gen_tproxy_local_vars (plist: Ast.pdecl list) =
  1089. let status_var = "sgx_status_t status = SGX_SUCCESS;\n" in
  1090. let do_gen_local_var (ty: Ast.atype) (attr: Ast.ptr_attr) (name: string) =
  1091. if not attr.Ast.pa_chkptr then ""
  1092. else "\t" ^ gen_ptr_size ty attr name (fun x -> x)
  1093. in
  1094. let gen_local_var (pd: Ast.pdecl) =
  1095. let (pty, declr) = pd in
  1096. match pty with
  1097. Ast.PTVal _ -> ""
  1098. | Ast.PTPtr (ty, attr) -> do_gen_local_var ty attr declr.Ast.identifier
  1099. in
  1100. let new_param_list = List.map conv_array_to_ptr plist
  1101. in
  1102. List.fold_left (fun acc pd -> acc ^ gen_local_var pd) status_var new_param_list
  1103. (* Generate only one ocalloc block required for the trusted proxy. *)
  1104. let gen_ocalloc_block (fname: string) (plist: Ast.pdecl list) =
  1105. let ms_struct_name = mk_ms_struct_name fname in
  1106. let local_vars_block = sprintf "%s* %s = NULL;\n\tsize_t ocalloc_size = sizeof(%s);\n\tvoid *__tmp = NULL;\n\n" ms_struct_name ms_struct_val ms_struct_name in
  1107. let count_ocalloc_size (ty: Ast.atype) (attr: Ast.ptr_attr) (name: string) =
  1108. if not attr.Ast.pa_chkptr then ""
  1109. else sprintf "\tocalloc_size += (%s != NULL && sgx_is_within_enclave(%s, %s)) ? %s : 0;\n" name name (mk_len_var name) (mk_len_var name)
  1110. in
  1111. let do_count_ocalloc_size (pd: Ast.pdecl) =
  1112. let (pty, declr) = pd in
  1113. match pty with
  1114. Ast.PTVal _ -> ""
  1115. | Ast.PTPtr (ty, attr) -> count_ocalloc_size ty attr declr.Ast.identifier
  1116. in
  1117. let do_gen_ocalloc_block = [
  1118. "\n\t__tmp = sgx_ocalloc(ocalloc_size);\n";
  1119. "\tif (__tmp == NULL) {\n";
  1120. "\t\tsgx_ocfree();\n";
  1121. "\t\treturn SGX_ERROR_UNEXPECTED;\n";
  1122. "\t}\n";
  1123. sprintf "\t%s = (%s*)__tmp;\n" ms_struct_val ms_struct_name;
  1124. sprintf "\t__tmp = (void *)((size_t)__tmp + sizeof(%s));\n" ms_struct_name;
  1125. ]
  1126. in
  1127. let new_param_list = List.map conv_array_to_ptr plist
  1128. in
  1129. let s1 = List.fold_left (fun acc pd -> acc ^ do_count_ocalloc_size pd) local_vars_block new_param_list in
  1130. List.fold_left (fun acc s -> acc ^ s) s1 do_gen_ocalloc_block
  1131. (* Generate trusted proxy code for a given untrusted function. *)
  1132. let gen_func_tproxy (ufunc: Ast.untrusted_func) (idx: int) =
  1133. let fd = ufunc.Ast.uf_fdecl in
  1134. let propagate_errno = ufunc.Ast.uf_propagate_errno in
  1135. let func_open = sprintf "%s\n{\n" (gen_tproxy_proto fd) in
  1136. let local_vars = gen_tproxy_local_vars fd.Ast.plist in
  1137. let ocalloc_ms_struct = gen_ocalloc_block fd.Ast.fname fd.Ast.plist in
  1138. let gen_ocfree rtype plist =
  1139. if rtype = Ast.Void && plist = [] then "" else "\tsgx_ocfree();\n"
  1140. in
  1141. let handle_out_ptr plist =
  1142. let copy_memory (attr: Ast.ptr_attr) (declr: Ast.declarator) =
  1143. let name = declr.Ast.identifier in
  1144. match attr.Ast.pa_direction with
  1145. Ast.PtrInOut | Ast.PtrOut ->
  1146. sprintf "\tif (%s) memcpy((void*)%s, %s, %s);\n" name name (mk_parm_accessor name) (mk_len_var name)
  1147. | _ -> ""
  1148. in List.fold_left (fun acc (pty, declr) ->
  1149. match pty with
  1150. Ast.PTVal _ -> acc
  1151. | Ast.PTPtr(ty, attr) -> acc ^ copy_memory attr declr) "" plist in
  1152. let set_errno = if propagate_errno then "\terrno = ms->ocall_errno;" else "" in
  1153. let func_close = sprintf "%s%s\n%s%s\n"
  1154. (handle_out_ptr fd.Ast.plist)
  1155. set_errno
  1156. (gen_ocfree fd.Ast.rtype fd.Ast.plist)
  1157. "\treturn status;\n}" in
  1158. let ocall_null = sprintf "status = sgx_ocall(%d, NULL);\n" idx in
  1159. let ocall_with_ms = sprintf "status = sgx_ocall(%d, %s);\n"
  1160. idx ms_struct_val in
  1161. let update_retval = sprintf "if (%s) *%s = %s;"
  1162. retval_name retval_name (mk_parm_accessor retval_name) in
  1163. let func_body = ref [] in
  1164. if (is_naked_func fd) && (propagate_errno = false) then
  1165. sprintf "%s\t%s\t%s%s" func_open local_vars ocall_null func_close
  1166. else
  1167. begin
  1168. func_body := local_vars :: !func_body;
  1169. func_body := ocalloc_ms_struct:: !func_body;
  1170. List.iter (fun pd -> func_body := tproxy_fill_ms_field pd :: !func_body) fd.Ast.plist;
  1171. func_body := ocall_with_ms :: !func_body;
  1172. if fd.Ast.rtype <> Ast.Void then func_body := update_retval :: !func_body;
  1173. List.fold_left (fun acc s -> acc ^ "\t" ^ s ^ "\n") func_open (List.rev !func_body) ^ func_close
  1174. end
  1175. (* It generates OCALL table and the untrusted proxy to setup OCALL table. *)
  1176. let gen_ocall_table (ec: enclave_content) =
  1177. let func_proto_ubridge = List.map (fun (uf: Ast.untrusted_func) ->
  1178. let fd : Ast.func_decl = uf.Ast.uf_fdecl in
  1179. mk_ubridge_name ec.file_shortnm fd.Ast.fname)
  1180. ec.ufunc_decls in
  1181. let nr_ocall = List.length ec.ufunc_decls in
  1182. let ocall_table_name = mk_ocall_table_name ec.enclave_name in
  1183. let ocall_table =
  1184. let ocall_members =
  1185. List.fold_left
  1186. (fun acc proto -> acc ^ "\t\t(void*)" ^ proto ^ ",\n") "" func_proto_ubridge
  1187. in "\t{\n" ^ ocall_members ^ "\t}\n"
  1188. in
  1189. sprintf "static const struct {\n\
  1190. \tsize_t nr_ocall;\n\
  1191. \tvoid * table[%d];\n\
  1192. } %s = {\n\
  1193. \t%d,\n\
  1194. %s};\n" (max nr_ocall 1)
  1195. ocall_table_name
  1196. nr_ocall
  1197. (if nr_ocall <> 0 then ocall_table else "\t{ NULL },\n")
  1198. (* It generates untrusted code to be saved in a `.c' file. *)
  1199. let gen_untrusted_source (ec: enclave_content) =
  1200. let code_fname = get_usource_name ec.file_shortnm in
  1201. let include_hd = "#include \"" ^ get_uheader_short_name ec.file_shortnm ^ "\"\n" in
  1202. let include_errno = "#include <errno.h>\n" in
  1203. let uproxy_list =
  1204. List.map2 (fun fd ecall_idx -> gen_func_uproxy fd ecall_idx ec)
  1205. (tf_list_to_fd_list ec.tfunc_decls)
  1206. (Util.mk_seq 0 (List.length ec.tfunc_decls - 1))
  1207. in
  1208. let ubridge_list =
  1209. List.map (fun fd -> gen_func_ubridge ec.file_shortnm fd)
  1210. (ec.ufunc_decls) in
  1211. let out_chan = open_out code_fname in
  1212. output_string out_chan (include_hd ^ include_errno ^ "\n");
  1213. ms_writer out_chan ec;
  1214. List.iter (fun s -> output_string out_chan (s ^ "\n")) ubridge_list;
  1215. output_string out_chan (gen_ocall_table ec);
  1216. List.iter (fun s -> output_string out_chan (s ^ "\n")) uproxy_list;
  1217. close_out out_chan
  1218. (* It generates trusted code to be saved in a `.c' file. *)
  1219. let gen_trusted_source (ec: enclave_content) =
  1220. let code_fname = get_tsource_name ec.file_shortnm in
  1221. let include_hd = "#include \"" ^ get_theader_short_name ec.file_shortnm ^ "\"\n\n\
  1222. #include \"sgx_trts.h\" /* for sgx_ocalloc, sgx_is_outside_enclave */\n\n\
  1223. #include <errno.h>\n\
  1224. #include <string.h> /* for memcpy etc */\n\
  1225. #include <stdlib.h> /* for malloc/free etc */\n\
  1226. \n\
  1227. #define CHECK_REF_POINTER(ptr, siz) do {\t\\\n\
  1228. \tif (!(ptr) || ! sgx_is_outside_enclave((ptr), (siz)))\t\\\n\
  1229. \t\treturn SGX_ERROR_INVALID_PARAMETER;\\\n\
  1230. } while (0)\n\
  1231. \n\
  1232. #define CHECK_UNIQUE_POINTER(ptr, siz) do {\t\\\n\
  1233. \tif ((ptr) && ! sgx_is_outside_enclave((ptr), (siz)))\t\\\n\
  1234. \t\treturn SGX_ERROR_INVALID_PARAMETER;\\\n\
  1235. } while (0)\n\
  1236. \n" in
  1237. let trusted_fds = tf_list_to_fd_list ec.tfunc_decls in
  1238. let tbridge_list =
  1239. let dummy_var = tbridge_gen_dummy_variable ec in
  1240. List.map (fun tfd -> gen_func_tbridge tfd dummy_var) trusted_fds in
  1241. let ecall_table = gen_ecall_table ec.tfunc_decls in
  1242. let entry_table = gen_entry_table ec in
  1243. let tproxy_list = List.map2
  1244. (fun fd idx -> gen_func_tproxy fd idx)
  1245. (ec.ufunc_decls)
  1246. (Util.mk_seq 0 (List.length ec.ufunc_decls - 1)) in
  1247. let out_chan = open_out code_fname in
  1248. output_string out_chan (include_hd ^ "\n");
  1249. ms_writer out_chan ec;
  1250. List.iter (fun s -> output_string out_chan (s ^ "\n")) tbridge_list;
  1251. output_string out_chan (ecall_table ^ "\n");
  1252. output_string out_chan (entry_table ^ "\n");
  1253. output_string out_chan "\n";
  1254. List.iter (fun s -> output_string out_chan (s ^ "\n")) tproxy_list;
  1255. close_out out_chan
  1256. (* We use a stack to keep record of imported files.
  1257. *
  1258. * A file will be pushed to the stack before we parsing it,
  1259. * and we will pop the stack after each `parse_import_file'.
  1260. *)
  1261. let already_read = SimpleStack.create ()
  1262. let save_file fullpath =
  1263. if SimpleStack.mem fullpath already_read
  1264. then failwithf "detected circled import for `%s'" fullpath
  1265. else SimpleStack.push fullpath already_read
  1266. (* The entry point of the Edger8r parser front-end.
  1267. * ------------------------------------------------
  1268. *)
  1269. let start_parsing (fname: string) : Ast.enclave =
  1270. let set_initial_pos lexbuf filename =
  1271. lexbuf.Lexing.lex_curr_p <- {
  1272. lexbuf.Lexing.lex_curr_p with Lexing.pos_fname = fname;
  1273. }
  1274. in
  1275. try
  1276. let fullpath = Util.get_file_path fname in
  1277. let preprocessed =
  1278. save_file fullpath; Preprocessor.processor_macro(fullpath) in
  1279. let lexbuf =
  1280. match preprocessed with
  1281. | None ->
  1282. let chan = open_in fullpath in
  1283. Lexing.from_channel chan
  1284. | Some(preprocessed_string) -> Lexing.from_string preprocessed_string
  1285. in
  1286. try
  1287. set_initial_pos lexbuf fname;
  1288. let e : Ast.enclave = Parser.start_parsing Lexer.tokenize lexbuf in
  1289. let short_name = Util.get_short_name fname in
  1290. if short_name = ""
  1291. then (eprintf "error: %s: file short name is empty\n" fname; exit 1;)
  1292. else
  1293. let res = { e with Ast.ename = short_name } in
  1294. if Util.is_c_identifier short_name then res
  1295. else (eprintf "warning: %s: file short name `%s' is not a valid C identifier\n" fname short_name; res)
  1296. with exn ->
  1297. begin match exn with
  1298. | Parsing.Parse_error ->
  1299. let curr = lexbuf.Lexing.lex_curr_p in
  1300. let line = curr.Lexing.pos_lnum in
  1301. let cnum = curr.Lexing.pos_cnum - curr.Lexing.pos_bol in
  1302. let tok = Lexing.lexeme lexbuf in
  1303. failwithf "%s:%d:%d: unexpected token: %s\n" fname line cnum tok
  1304. | _ -> raise exn
  1305. end
  1306. with Sys_error s -> failwithf "%s\n" s
  1307. (* Check duplicated ECALL/OCALL names.
  1308. *
  1309. * This is a pretty simple implementation - to improve it, the
  1310. * location information of each token should be carried to AST.
  1311. *)
  1312. let check_duplication (ec: enclave_content) =
  1313. let dict = Hashtbl.create 10 in
  1314. let trusted_fds = tf_list_to_fd_list ec.tfunc_decls in
  1315. let untrusted_fds = uf_list_to_fd_list ec.ufunc_decls in
  1316. let check_and_add fname =
  1317. if Hashtbl.mem dict fname then
  1318. failwithf "Multiple definition of function \"%s\" detected." fname
  1319. else
  1320. Hashtbl.add dict fname true
  1321. in
  1322. List.iter (fun (fd: Ast.func_decl) ->
  1323. check_and_add fd.Ast.fname) (trusted_fds @ untrusted_fds)
  1324. (* For each untrusted functions, check that allowed ECALL does exist. *)
  1325. let check_allow_list (ec: enclave_content) =
  1326. let trusted_func_names = get_trusted_func_names ec in
  1327. let do_check_allow_list fname allowed_ecalls =
  1328. List.iter (fun trusted_func ->
  1329. if List.exists (fun x -> x = trusted_func) trusted_func_names
  1330. then ()
  1331. else
  1332. failwithf "\"%s\" declared to allow unknown function \"%s\"."
  1333. fname trusted_func) allowed_ecalls
  1334. in
  1335. List.iter (fun (uf: Ast.untrusted_func) ->
  1336. let fd = uf.Ast.uf_fdecl in
  1337. let allowed_ecalls = uf.Ast.uf_allow_list in
  1338. do_check_allow_list fd.Ast.fname allowed_ecalls) ec.ufunc_decls
  1339. (* Report private ECALL not used in any "allow(...)" expression. *)
  1340. let report_orphaned_priv_ecall (ec: enclave_content) =
  1341. let priv_ecall_names = get_priv_ecall_names ec.tfunc_decls in
  1342. let allowed_names = get_allowed_names ec.ufunc_decls in
  1343. let check_ecall n = if List.mem n allowed_names then ()
  1344. else eprintf "warning: private ECALL `%s' is not used by any OCALL\n" n
  1345. in
  1346. List.iter check_ecall priv_ecall_names
  1347. (* Check that there is at least one public ECALL function. *)
  1348. let check_priv_funcs (ec: enclave_content) =
  1349. let priv_bits = tf_list_to_priv_list ec.tfunc_decls in
  1350. if List.for_all (fun is_priv -> is_priv) priv_bits
  1351. then failwithf "the enclave `%s' contains no public root ECALL.\n" ec.file_shortnm
  1352. else report_orphaned_priv_ecall ec
(* When generating edge-routines, we first need to check whether the EDL
 * contains `import' expressions. If so, each imported file is parsed into
 * an `enclave_content' record, recursively.
 *
 * `ec' is the toplevel `enclave_content' record.
 * A tree-reduce algorithm is used: `ec' is the root node and each
 * `import' expression is considered a child node. The children are folded
 * into `ec' with `combine', so the returned record has `import_exprs'
 * set to [] and carries the declarations pulled in from imported files.
 *)
let reduce_import (ec: enclave_content) =
(* Merge `ec2' into `ec1': include lists, composite definitions and
 * trusted/untrusted declarations are concatenated (ec1's entries first);
 * `import_exprs' is cleared; every other field is taken from `ec1'. *)
let combine (ec1: enclave_content) (ec2: enclave_content) =
{ ec1 with
include_list = ec1.include_list @ ec2.include_list;
import_exprs = [];
comp_defs = ec1.comp_defs @ ec2.comp_defs;
tfunc_decls = ec1.tfunc_decls @ ec2.tfunc_decls;
ufunc_decls = ec1.ufunc_decls @ ec2.ufunc_decls; }
in
(* Parse the imported EDL file `fname' into an `enclave_content' record.
 * When that file has no `import' expressions of its own, the top entry of
 * the `already_read' stack is popped; otherwise the stack is untouched.
 * NOTE(review): presumably `start_parsing' pushes `fname' onto
 * `already_read' (e.g. for circular-import detection) - confirm. *)
let parse_import_file fname =
let ec = parse_enclave_ast (start_parsing fname)
in
match ec.import_exprs with
[] -> (SimpleStack.pop already_read |> ignore; ec )
| _ -> ec
in
let check_funs funcs (ec: enclave_content) =
(* Check whether `funcs' are listed in `ec'. It returns a
pair (x, y), where:
x - names in `funcs' that are not trusted/untrusted functions of `ec';
y - a new record, built from `empty_ec', holding only the trusted and
untrusted declarations of `ec' whose names appear in `funcs'.
*)
let enclave_funcs =
let trusted_func_names = get_trusted_func_names ec in
let untrusted_func_names = get_untrusted_func_names ec in
trusted_func_names @ untrusted_func_names
in
let in_ec_def name = List.exists (fun x -> x = name) enclave_funcs in
let in_import_list name = List.exists (fun x -> x = name) funcs in
let x = List.filter (fun name -> not (in_ec_def name)) funcs in
let y =
{ empty_ec with
tfunc_decls = List.filter (fun tf ->
in_import_list (get_tf_fname tf)) ec.tfunc_decls;
ufunc_decls = List.filter (fun uf ->
in_import_list (get_uf_fname uf)) ec.ufunc_decls; }
in (x, y)
in
(* Import functions listed in `funcs' from `importee', recursing into the
 * files `importee' itself imports for any name it does not declare.
 * NOTE(review): in the non-`*' branch, every child import is asked for
 * all still-missing names `x' (the child's own `ipd.Ast.flist' is not
 * consulted). A child with no further imports that lacks one of them
 * raises via `failwithf' even if a sibling supplies it, and a name found
 * in several children is collected more than once - confirm intended. *)
let rec import_funcs (funcs: string list) (importee: enclave_content) =
(* A `*' means importing all the functions: the result is `importee'
 * itself (including its includes and composite definitions) combined
 * with everything pulled from each of its own imports. *)
if List.exists (fun x -> x = "*") funcs
then
List.fold_left (fun acc (ipd: Ast.import_decl) ->
let next_ec = parse_import_file ipd.Ast.mname
in combine acc (import_funcs ipd.Ast.flist next_ec)) importee importee.import_exprs
else
let (x, y) = check_funs funcs importee
in
if x = [] then y (* All requested functions were resolved here. *)
else
(* Some names are missing: look for them in the files `importee'
 * imports. On failure only the first missing name is reported. *)
match importee.import_exprs with
[] -> failwithf "import failed - functions `%s' not found" (List.hd x)
| ex -> List.fold_left (fun acc (ipd: Ast.import_decl) ->
let next_ec = parse_import_file ipd.Ast.mname
in combine acc (import_funcs x next_ec)) y ex
in
(* Start at the root with `*': keep all of `ec' and follow its imports. *)
import_funcs ["*"] ec
  1419. (* Generate the Enclave code. *)
  1420. let gen_enclave_code (e: Ast.enclave) (ep: edger8r_params) =
  1421. let ec = reduce_import (parse_enclave_ast e) in
  1422. g_use_prefix := ep.use_prefix;
  1423. g_untrusted_dir := ep.untrusted_dir;
  1424. g_trusted_dir := ep.trusted_dir;
  1425. create_dir ep.untrusted_dir;
  1426. create_dir ep.trusted_dir;
  1427. check_duplication ec;
  1428. check_allow_list ec;
  1429. (if not ep.header_only then check_priv_funcs ec);
  1430. (if ep.gen_untrusted then (gen_untrusted_header ec; if not ep.header_only then gen_untrusted_source ec));
  1431. (if ep.gen_trusted then (gen_trusted_header ec; if not ep.header_only then gen_trusted_source ec))