You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

302 lines
9.8 KiB

6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. import configparser
  2. import os
  3. import argparse
  4. import pymongo
  5. import ssl
  6. import mysql.connector
  7. import networkx as nx
  8. from . import queries
  9. from . import graph
  10. import minorminer
  11. from tqdm import tqdm
  12. import numpy as np
  13. def readConfig(configFilePath):
  14. config = configparser.ConfigParser()
  15. if os.path.isfile(configFilePath):
  16. config.read(configFilePath)
  17. return config
  18. class ArgParser:
  19. def __init__(self):
  20. self.__flags = {}
  21. self.__parser = argparse.ArgumentParser()
  22. self.__instanceDirArgSet = False
  23. self.__config = None
  24. self.__parsedArgs = {}
  25. def addArg(self, alias,
  26. shortFlag,
  27. longFlag,
  28. help,
  29. type,
  30. default=None,
  31. ignoreDatabaseConfig=False):
  32. self.__flags[alias] = {"longFlag": longFlag,
  33. "hasDefault": False,
  34. "ignoreDatabaseConfig": ignoreDatabaseConfig,
  35. "type": type}
  36. if default != None:
  37. self.__flags[alias]["hasDefault"] = True
  38. self.__parser.add_argument("-%s" % shortFlag,
  39. "--%s" % longFlag,
  40. help=help,
  41. type=type,
  42. default=default)
  43. def addInstanceDirArg(self):
  44. self.__instanceDirArgSet = True
  45. self.addArg(alias="datasetDir", shortFlag="d", longFlag="dataset_dir",
  46. help="the base direcotry of the dataset; if this flag is given the others can be omitted",
  47. type=str, ignoreDatabaseConfig=True)
  48. def parse(self):
  49. self.__parsedArgs = {}
  50. args = vars(self.__parser.parse_args())
  51. if self.__instanceDirArgSet:
  52. self.__config = readConfig(os.path.join(args["dataset_dir"],
  53. "dataset.config"))
  54. self.__parseDatasetConfig()
  55. for alias, flag in self.__flags.items():
  56. self.__parsedArgs[alias] = self.__processFlag(args, flag)
  57. self.__config = None
  58. return self.__parsedArgs
  59. def __parseDatasetConfig(self):
  60. for flag, value in self.__config["STRUCTURE"].items():
  61. self.__parsedArgs[flag] = value
  62. def __processFlag(self, args, flag):
  63. longFlag = flag["longFlag"]
  64. tmpValue = self.__parsedArgs[longFlag] if longFlag in self.__parsedArgs else None
  65. if flag["ignoreDatabaseConfig"] == True:
  66. tmpValue = None
  67. if args[longFlag]:
  68. tmpValue = args[longFlag]
  69. if tmpValue == None:
  70. tmpValue = flag["type"](input("pass arguement %s: " % longFlag))
  71. return tmpValue
  72. def getDBContext(dbConfigPath):
  73. dbContext = {}
  74. dbContext["client"] = connect_to_instance_pool(dbConfigPath)
  75. dbContext["db"] = dbContext["client"]["experiments"]
  76. dbContext["instances"] = dbContext["db"]["instances"]
  77. dbContext["experimentScopes"] = dbContext["db"]["experiment_scopes"]
  78. return dbContext
  79. def connect_to_instance_pool(dbConfigPath = "database.config"):
  80. dbConf = readConfig(dbConfigPath)
  81. client = pymongo.MongoClient(
  82. "mongodb://%s:%s@%s:%s/%s"
  83. % ( dbConf["INSTANCE_POOL"]["user"],
  84. dbConf["INSTANCE_POOL"]["pw"],
  85. dbConf["INSTANCE_POOL"]["url"],
  86. dbConf["INSTANCE_POOL"]["port"],
  87. dbConf["INSTANCE_POOL"]["database"]),
  88. ssl=True,
  89. ssl_cert_reqs=ssl.CERT_NONE)
  90. return client[dbConf["INSTANCE_POOL"]["database"]]
  91. def connect_to_experimetns_db(dbConfigPath = "database.config"):
  92. dbConfig = readConfig(dbConfigPath)
  93. return mysql.connector.connect(
  94. host=dbConfig["EXPERIMENT_DB"]["url"],
  95. port=dbConfig["EXPERIMENT_DB"]["port"],
  96. user=dbConfig["EXPERIMENT_DB"]["user"],
  97. password=dbConfig["EXPERIMENT_DB"]["pw"],
  98. database=dbConfig["EXPERIMENT_DB"]["database"]
  99. )
  100. def frange(start, stop, steps):
  101. while start < stop:
  102. yield start
  103. start += steps
  104. def create_experiment_scope(db, description, name):
  105. experimentScope = {}
  106. experimentScope["instances"] = []
  107. experimentScope["description"] = description
  108. experimentScope["_id"] = name.strip()
  109. db["experiment_scopes"].insert_one(experimentScope)
  110. def write_instance_to_pool_db(db, instance):
  111. instance_document = instance.writeJSONLike()
  112. result = db["instances"].insert_one(instance_document)
  113. return result.inserted_id
  114. def add_instance_to_experiment_scope(db, scope_name, instance_id):
  115. db["experiment_scopes"].update_one(
  116. {"_id": scope_name},
  117. {"$push": {"instances": instance_id}}
  118. )
  119. def write_qubo_to_pool_db(collection, qubo, sat_instance_id):
  120. doc = {}
  121. doc["instance"] = sat_instance_id
  122. doc["description"] = {"<qubo>": "<entrys>",
  123. "<entrys>": "<entry><entrys> | <entry> | \"\"",
  124. "<entry>": "<coupler><energy>",
  125. "<energy>": "<real_number>",
  126. "<coupler>": "<node><node>",
  127. "<node>": "<clause><literal>",
  128. "<clause>": "<natural_number>",
  129. "<literal>": "<integer>"}
  130. doc["qubo"] = __qubo_to_JSON(qubo)
  131. collection.insert_one(doc)
  132. def __qubo_to_JSON(qubo):
  133. quboJSON = []
  134. for coupler, value in qubo.items():
  135. quboJSON.append([coupler, float(value)])
  136. return quboJSON
  137. def write_wmis_embedding_to_pool_db(collection, qubo_id, solver_graph_id, embedding):
  138. if not __embedding_entry_exists(collection, qubo_id, solver_graph_id):
  139. __prepare_new_wmis_embedding_entry(collection, qubo_id, solver_graph_id)
  140. collection.update_one(
  141. {"qubo": qubo_id, "solver_graph": solver_graph_id},
  142. {"$push": {"embeddings": __embedding_to_array(embedding)}}
  143. )
  144. def __embedding_entry_exists(collection, qubo_id, solver_graph_id):
  145. filter = {"qubo": qubo_id, "solver_graph": solver_graph_id}
  146. if collection.count_documents(filter) > 0:
  147. return True
  148. return False
  149. def __prepare_new_wmis_embedding_entry(collection, qubo_id, solver_graph_id):
  150. doc = {}
  151. doc["qubo"] = qubo_id
  152. doc["solver_graph"] = solver_graph_id
  153. doc["description"] = {"<embedding>": "<chains>",
  154. "<chains>": "<chain><chains> | \"\"",
  155. "<chain>" : "<original_node><chimera_nodes>",
  156. "<chimera_nodes>": "<chimera_node><chimera_nodes> | \"\""}
  157. doc["embeddings"] = []
  158. collection.insert_one(doc)
  159. def __embedding_to_array(embedding):
  160. emb_arr = []
  161. for node, chain in embedding.items():
  162. emb_arr.append([node, chain])
  163. return emb_arr
  164. def write_solver_graph_to_pool_db(collection, graph):
  165. data = nx.node_link_data(graph)
  166. id = queries.get_id_of_solver_graph(collection, data)
  167. if id != None:
  168. return id
  169. doc = {}
  170. doc["data"] = data
  171. return collection.insert_one(doc).inserted_id
  172. def find_wmis_embeddings_for_scope(db, scope, solver_graph):
  173. solver_graph_id = write_solver_graph_to_pool_db(db["solver_graphs"],
  174. solver_graph)
  175. qubos = queries.WMIS_scope_query(db)
  176. qubos.query(scope)
  177. new_embeddings_found = 0
  178. already_found = 0
  179. total_count = 0
  180. for qubo, qubo_id in tqdm(qubos):
  181. total_count += 1
  182. max_no_improvement = 10
  183. for i in range(5):
  184. if __embedding_entry_exists(db["embeddings"], qubo_id, solver_graph_id):
  185. already_found += 1
  186. break;
  187. else:
  188. nx_qubo = graph.qubo_to_nx_graph(qubo)
  189. emb = minorminer.find_embedding(nx_qubo.edges(),
  190. solver_graph.edges(),
  191. return_overlap=True,
  192. max_no_improvement=max_no_improvement)
  193. if emb[1] == 1:
  194. write_wmis_embedding_to_pool_db(db["embeddings"],
  195. qubo_id,
  196. solver_graph_id,
  197. emb[0])
  198. new_embeddings_found += 1
  199. max_no_improvement *= 1.5
  200. percentage = 0
  201. if total_count > 0:
  202. percentage = round(((new_embeddings_found + already_found) / total_count) * 100)
  203. print("found {} of {} embeddigns ({}%)".format(new_embeddings_found + already_found,
  204. total_count,
  205. percentage))
  206. print("{} new embeddigns found".format(new_embeddings_found))
  207. def save_simulated_annealing_result(collection, result, solver_input, emb_list_index):
  208. doc = {}
  209. doc["data"] = result.to_serializable()
  210. doc["instance"] = solver_input["instance_id"]
  211. doc["embedding"] = {
  212. "embedding_id": solver_input["embeddings_id"],
  213. "list_index": emb_list_index
  214. }
  215. collection.insert_one(doc)
  216. def analyze_wmis_sample(sample):
  217. data = {}
  218. data["number_of_assignments"] = np.count_nonzero(list(sample.sample.values()))
  219. data["chain_break_fraction"] = sample.chain_break_fraction
  220. data["num_occurrences"] = sample.num_occurrences
  221. data["energy"] = sample.energy
  222. return data