data.h
2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/*===--------------------------------------------------------------------------
* ATMI (Asynchronous Task and Memory Interface)
*
* This file is distributed under the MIT License. See LICENSE.txt for details.
*===------------------------------------------------------------------------*/
#ifndef SRC_RUNTIME_INCLUDE_DATA_H_
#define SRC_RUNTIME_INCLUDE_DATA_H_
#include "atmi.h"
#include <hsa.h>
#include <functional>
#include <map>
#include <mutex>
#include <stdio.h>
#include <stdlib.h>
// We maintain our own mapping from device address to a user-specified data
// object in order to work around a (possibly historic) bug in ROCr's
// hsa_amd_pointer_info_set_userdata for variable symbols.
// This is expected to be temporary.
namespace core {
// ATLData: ATMI's internal record of a piece of data that ATMI itself created
// and manages. The data may reside in the memory of any device or in host
// memory; place() reports where.
class ATLData {
  void *ptr_;              // start address of the data
  size_t size_;            // size of the data (presumably in bytes)
  atmi_mem_place_t place_; // memory place the data lives in

public:
  ATLData(void *ptr, size_t size, atmi_mem_place_t place)
      : ptr_{ptr}, size_{size}, place_{place} {}

  // Read-only accessors; the record is immutable after construction.
  auto ptr() const -> void * { return ptr_; }
  auto size() const -> size_t { return size_; }
  auto place() const -> atmi_mem_place_t { return place_; }
};
//---
struct ATLMemoryRange {
const void *base_pointer;
const void *end_pointer;
ATLMemoryRange(const void *bp, size_t size_bytes)
: base_pointer(bp),
end_pointer(reinterpret_cast<const unsigned char *>(bp) + size_bytes -
1) {}
};
// Functor to compare ranges:
struct ATLMemoryRangeCompare {
// Return true is LHS range is less than RHS - used to order the ranges
bool operator()(const ATLMemoryRange &lhs, const ATLMemoryRange &rhs) const {
return lhs.end_pointer < rhs.base_pointer;
}
};
//-------------------------------------------------------------------------------------------------
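// This structure tracks the ATLData associated with each inserted pointer.
// Lookups are memory-range based, so a pointer that falls anywhere in the
// range [pointer, pointer + size) of a tracked entry will find the
// associated ATLData.
// Insertions and lookups use a self-balancing binary search tree (std::map)
// and should support O(log N) lookups.
// The structure is thread-safe: writers obtain a mutex before modifying the
// tree. Multiple simultaneous readers are supported.
// NOTE(review): mutex_ is an exclusive std::mutex, so if find() locks it,
// readers are serialized rather than concurrent; if it does not, reads can
// race with writers. Confirm against the member function definitions.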
class ATLPointerTracker {
typedef std::map<ATLMemoryRange, ATLData *, ATLMemoryRangeCompare>
MapTrackerType;
public:
void insert(void *pointer, ATLData *data);
void remove(void *pointer);
ATLData *find(const void *pointer);
private:
MapTrackerType tracker_;
std::mutex mutex_;
};
// Global tracker for all pointer allocations ("am" in the original note;
// presumably ATMI-managed memory). It is declared here and must be defined
// in exactly one source file.
extern ATLPointerTracker g_data_map;

// Presumably the direction of a memory transfer:
// host->device, device->host, device->device, host->host.
enum class Direction {
  ATMI_H2D,
  ATMI_D2H,
  ATMI_D2D,
  ATMI_H2H
};
} // namespace core
#endif // SRC_RUNTIME_INCLUDE_DATA_H_